torch_staintools.functional.optimization package
Submodules
torch_staintools.functional.optimization.dict_learning module
Code directly adapted from https://github.com/rfeinman/pytorch-lasso
- torch_staintools.functional.optimization.dict_learning.dict_learning(x: Tensor, tissue_mask_flatten: Tensor | None, n_components: int, algorithm: Literal['ista', 'cd', 'fista'], *, alpha: float, lambd_ridge: float, steps: int, rng: Generator | None, init: str | None, lr: float | None, maxiter: int)
- torch_staintools.functional.optimization.dict_learning.dict_learning_loop(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, algorithm: Literal['ista', 'cd', 'fista'], *, lambd_ridge: float, steps: int, rng: Generator, init: str | None, lr: Tensor, maxiter: int)
- torch_staintools.functional.optimization.dict_learning.sparse_code(x: Tensor, weight: Tensor, alpha: Tensor, z0: Tensor, algorithm: Literal['ista', 'cd', 'fista'], lr: Tensor, maxiter: int, positive_code: bool)
- torch_staintools.functional.optimization.dict_learning.update_dict_cd(dictionary: Tensor, x: Tensor, code: Tensor, positive: bool = True, dead_thresh=1e-07, rng: Generator = None) Tuple[Tensor, Tensor]
Update the dictionary (stain matrix) using the block coordinate descent algorithm.
Can enforce a positivity constraint on the dictionary if specified. Side effect: code is updated in place.
- Parameters:
dictionary – Tensor of shape (B, n_features, n_components). Value of the dictionary at the previous iteration. In this context, n_features can be the number of input color channels and n_components the number of stains to separate.
x – Tensor of shape (B, n_samples, n_features). Data against which to optimize the dictionary.
code – Tensor of shape (B, n_samples, n_components). Sparse coding of the data against which to optimize the dictionary.
positive – Whether to enforce positivity when finding the dictionary.
dead_thresh – Minimum vector norm below which an atom is considered degenerate (“dead”)
rng – torch.Generator for initialization of dictionary and code.
- Returns:
torch.Tensor, torch.Tensor, corresponding to the weight and the updated code.
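As a rough illustration of the technique, one block coordinate descent pass over the atoms might look like the following. This is a NumPy sketch of an unbatched update, not the library's batched torch implementation; all names are illustrative:

```python
import numpy as np

def update_dict_cd_sketch(dictionary, x, code, positive=True, dead_thresh=1e-7):
    """One block-coordinate-descent pass over dictionary atoms (unbatched sketch).

    dictionary: (n_features, n_components), x: (n_samples, n_features),
    code: (n_samples, n_components).
    """
    gram = code.T @ code   # (k, k): code covariance
    cov = x.T @ code       # (f, k): data/code covariance
    for k in range(dictionary.shape[1]):
        # Update atom k while holding the other atoms fixed.
        dictionary[:, k] += (cov[:, k] - dictionary @ gram[:, k]) / max(gram[k, k], dead_thresh)
        if positive:
            np.clip(dictionary[:, k], 0.0, None, out=dictionary[:, k])
        norm = np.linalg.norm(dictionary[:, k])
        if norm > dead_thresh:        # skip (near-)dead atoms instead of renormalizing
            dictionary[:, k] /= norm  # keep atoms at unit norm
    return dictionary
```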
- torch_staintools.functional.optimization.dict_learning.update_dict_ridge(x: Tensor, code: Tensor, lambd: float) Tuple[Tensor, Tensor]
Update an (unconstrained) dictionary with ridge regression.
This is equivalent to a Newton step with the (L2-regularized) squared error objective: f(V) = (1/2N) * ||Vz - x||_2^2 + (lambd/2) * ||V||_2^2. May have severe numerical stability issues compared to update_dict_cd.
- Parameters:
x – a batch of observations with shape (B, n_samples, n_features)
code – a batch of code vectors with shape (B, n_samples, n_components)
lambd – weight decay parameter
- Returns:
torch.Tensor, torch.Tensor, corresponding to the weight and the unmodified code.
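For intuition, the objective above has a closed-form minimizer. Setting the gradient with respect to V to zero gives (ZᵀZ + Nλ·I)Vᵀ = ZᵀX. An unbatched NumPy sketch (illustrative names, not the library's torch code):

```python
import numpy as np

def update_dict_ridge_sketch(x, code, lambd=1e-4):
    """Closed-form ridge update minimizing
    (1/2N) * ||code @ V.T - x||^2 + (lambd/2) * ||V||^2  (unbatched sketch).

    x: (n_samples, n_features), code: (n_samples, n_components).
    Returns V with shape (n_features, n_components).
    """
    n = x.shape[0]
    # Regularized normal equations: (Z.T Z + N*lambd*I) V.T = Z.T X
    gram = code.T @ code + n * lambd * np.eye(code.shape[1])
    return np.linalg.solve(gram, code.T @ x).T
```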
torch_staintools.functional.optimization.linear_solver module
- torch_staintools.functional.optimization.linear_solver.lstsq_solver(od_flatten: Tensor, dictionary: Tensor, positive: bool)
Least squares solver for concentration.
Warning
For the cusolver backend, algorithm = ‘ls’ may fail on GPU for individually large inputs (e.g., 1000 x 1000), regardless of batch size. To use ‘ls’ on large images, consider the magma backend:
`torch.backends.cuda.preferred_linalg_library('magma')`
- Parameters:
od_flatten – B x num_pixels x C
dictionary – Transpose of stain matrix. B x C x num_stains
positive – enforce positive concentration
- Returns:
B x num_pixels x num_stains
- Return type:
concentration (flattened)
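The underlying per-pixel model is od_i ≈ dictionary @ c_i, solved for c_i in the least-squares sense. An unbatched NumPy sketch of this idea (illustrative names; the library operates on batched torch tensors):

```python
import numpy as np

def lstsq_concentration_sketch(od_flatten, dictionary, positive=True):
    """Per-pixel least squares od_i ≈ dictionary @ c_i (unbatched sketch).

    od_flatten: (num_pixels, C); dictionary: (C, num_stains).
    """
    # Solve dictionary @ conc.T = od_flatten.T for all pixels at once.
    conc, *_ = np.linalg.lstsq(dictionary, od_flatten.T, rcond=None)
    conc = conc.T  # (num_pixels, num_stains)
    if positive:
        conc = np.clip(conc, 0.0, None)  # crude positivity: clamp negatives
    return conc
```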
- torch_staintools.functional.optimization.linear_solver.pinv_solver(od_flatten: Tensor, dictionary: Tensor, positive: bool)
Pseudo-inverse solver for concentration.
- Parameters:
od_flatten – B x num_pixels x C
dictionary – Transpose of stain matrix. B x C x num_stains
positive – enforce positive concentration
- Returns:
B x num_pixels x num_stains
- Return type:
concentration (flattened)
- torch_staintools.functional.optimization.linear_solver.qr_solver(od_flatten, dictionary, positive: bool)
QR solver for concentration.
- Parameters:
od_flatten – B x num_pixels x C
dictionary – Transpose of stain matrix. B x C x num_stains
positive – enforce positive concentration
- Returns:
B x num_pixels x num_stains
- Return type:
concentration (flattened)
- torch_staintools.functional.optimization.linear_solver.qr_solver_generic(od_flatten, dictionary, positive: bool)
QR solver for concentration.
Warning
Compared to the 2-stain unrolled version `qr_solver_two_stain`, this does not scale well with the size of individual images.
- Parameters:
od_flatten – B x num_pixels x C
dictionary – Transpose of stain matrix. B x C x num_stains
positive – enforce positive concentration
- Returns:
B x num_pixels x num_stains
- Return type:
concentration (flattened)
- torch_staintools.functional.optimization.linear_solver.qr_solver_two_stain(od_flatten, dictionary, positive: bool)
QR solver for concentration (hardcoded 2-stain computation)
- Parameters:
od_flatten – B x num_pixels x C
dictionary – Transpose of stain matrix. B x C x num_stains
positive – enforce positive concentration
- Returns:
B x num_pixels x num_stains
- Return type:
concentration (flattened)
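With exactly two stains the per-pixel solve can be unrolled into closed form. The sketch below uses the 2x2 normal equations rather than the library's QR factorization, so it illustrates the unrolling idea but not the exact algorithm (unbatched NumPy, illustrative names):

```python
import numpy as np

def two_stain_solve_sketch(od_flatten, dictionary, positive=True):
    """Closed-form 2-stain least squares via the 2x2 normal equations.

    od_flatten: (num_pixels, C); dictionary: (C, 2).
    """
    a, b = dictionary[:, 0], dictionary[:, 1]
    # Entries of the 2x2 Gram matrix D.T @ D and the right-hand side D.T @ od_i.
    g11, g12, g22 = a @ a, a @ b, b @ b
    r1, r2 = od_flatten @ a, od_flatten @ b
    det = g11 * g22 - g12 * g12
    # Explicit 2x2 inverse applied to the right-hand side.
    c1 = (g22 * r1 - g12 * r2) / det
    c2 = (g11 * r2 - g12 * r1) / det
    conc = np.stack([c1, c2], axis=1)  # (num_pixels, 2)
    if positive:
        conc = np.clip(conc, 0.0, None)
    return conc
```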
torch_staintools.functional.optimization.sparse_solver module
code directly adapted from https://github.com/rfeinman/pytorch-lasso
- torch_staintools.functional.optimization.sparse_solver.cd_loop(z: Tensor, b: Tensor, s: Tensor, alpha: Tensor, maxiter: int, positive_code: bool) Tensor
- torch_staintools.functional.optimization.sparse_solver.cd_step(z: Tensor, b: Tensor, s: Tensor, alpha: Tensor, positive_code: bool) tuple[Tensor, Tensor]
- torch_staintools.functional.optimization.sparse_solver.coord_descent(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, maxiter: int, positive_code: bool)
Modified coordinate descent solver.
- torch_staintools.functional.optimization.sparse_solver.fista(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, lr: Tensor, maxiter: int, positive_code)
Fast ISTA (FISTA) solver.
- Parameters:
x – data
z0 – initial code, or the initialization mode of the code.
weight – dictionary
alpha – penalty term for the code
lr – learning rate/step size. If ‘auto’, it is derived from the Lipschitz constant of f(z) = ||Wz - x||^2
maxiter – maximum number of iterations if not converged.
positive_code – whether to enforce the positivity constraint on z
Returns:
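For orientation, the general FISTA recipe (gradient step on the smooth term, soft-threshold, then momentum extrapolation) can be sketched as below. This is an unbatched NumPy illustration of the standard algorithm, not the library's torch implementation; all names are illustrative:

```python
import numpy as np

def fista_sketch(x, weight, alpha=1e-3, maxiter=200, positive=True):
    """FISTA for min_z 0.5*||z @ W.T - x||^2 + alpha*||z||_1 (unbatched sketch).

    x: (n_samples, n_features); weight W: (n_features, n_components).
    """
    hessian = weight.T @ weight               # precomputed W.T @ W
    b = x @ weight                            # precomputed x @ W
    lr = 1.0 / np.linalg.norm(hessian, 2)     # step size 1/L (L = spectral norm)
    z = np.zeros((x.shape[0], weight.shape[1]))
    y, t = z.copy(), 1.0
    for _ in range(maxiter):
        grad = y @ hessian - b                # gradient of the smooth term at y
        step = y - lr * grad
        if positive:                          # proximal step: soft-threshold
            z_next = np.clip(step - lr * alpha, 0.0, None)
        else:
            z_next = np.sign(step) * np.clip(np.abs(step) - lr * alpha, 0.0, None)
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        y = z_next + ((t - 1.0) / t_next) * (z_next - z)  # momentum extrapolation
        z, t = z_next, t_next
    return z
```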
- torch_staintools.functional.optimization.sparse_solver.fista_loop(z: Tensor, hessian: Tensor, b: Tensor, alpha: Tensor, lr: Tensor, maxiter: int, positive_code: bool = True) Tensor
FISTA Loop
- Parameters:
z – Initial guess
x (#) – Data input (OD space)
weight (#) – Dictionary matrix
hessian – precomputed W.T @ W
b – precomputed x @ W
alpha – Regularization strength
lr – Learning rate
maxiter – Maximum iterations
positive_code – whether to enforce positivity of the code
- torch_staintools.functional.optimization.sparse_solver.fista_step(z: Tensor, y: Tensor, t: Tensor, hessian: Tensor, b: Tensor, alpha: Tensor, lr: Tensor, positive_code: bool, keep: Tensor)
- torch_staintools.functional.optimization.sparse_solver.ista(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, lr: Tensor, maxiter: int, positive_code: bool)
ISTA solver
- Parameters:
x – data
z0 – initial code, or the initialization mode of the code.
weight – dictionary
alpha – penalty term for the code
lr – learning rate/step size. If ‘auto’, it is derived from the Lipschitz constant of f(z) = ||Wz - x||^2
maxiter – maximum number of iterations if not converged.
positive_code – whether to enforce the positivity constraint on z
Returns:
- torch_staintools.functional.optimization.sparse_solver.ista_loop(z: Tensor, hessian: Tensor, b: Tensor, *, alpha: Tensor, lr: Tensor, maxiter: int, positive_code: bool)
- torch_staintools.functional.optimization.sparse_solver.ista_step(z: Tensor, hessian: Tensor, b: Tensor, alpha: Tensor, lr: Tensor, positive: bool) Tensor
- Parameters:
z – code. B x num_pixels x num_stains
x (#) – OD space. B x num_pixels x num_channels
weight (#) – init from stain matrix –> B x num_channels x num_stains
hessian – precomputed W.T @ W
b – precomputed x @ W
alpha – tensor form of the ISTA penalty
lr – tensor form of the step size
positive – whether to force z to be positive
Returns:
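A single proximal-gradient step with the precomputed quantities described above can be sketched as follows (unbatched NumPy illustration of the standard ISTA step, with illustrative names; not the library's torch code):

```python
import numpy as np

def ista_step_sketch(z, hessian, b, alpha, lr, positive=True):
    """One ISTA (proximal gradient) step with precomputed
    hessian = W.T @ W and b = x @ W (unbatched sketch)."""
    step = z - lr * (z @ hessian - b)  # gradient step on 0.5*||z @ W.T - x||^2
    if positive:
        # Positive soft-threshold: shrink, then clamp negatives to zero.
        return np.clip(step - lr * alpha, 0.0, None)
    return np.sign(step) * np.clip(np.abs(step) - lr * alpha, 0.0, None)
```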
- torch_staintools.functional.optimization.sparse_solver.rss_grad(z_k: Tensor, x: Tensor, weight: Tensor)
- torch_staintools.functional.optimization.sparse_solver.rss_grad_fast(z_k: Tensor, hessian: Tensor, b: Tensor)
- torch_staintools.functional.optimization.sparse_solver.softshrink(x: Tensor, lambd: Tensor, positive: bool) Tensor
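Soft-shrinkage is the proximal operator of the L1 penalty, the thresholding step inside ISTA/FISTA. A minimal NumPy sketch of the standard operator and its positivity-constrained variant (illustrative, not the library's torch implementation):

```python
import numpy as np

def softshrink_sketch(x, lambd, positive=False):
    """Soft-thresholding: the proximal operator of lambd * ||.||_1.

    With positive=True, negative values shrink to zero, i.e. the proximal
    operator restricted to the nonnegative orthant."""
    if positive:
        return np.clip(x - lambd, 0.0, None)
    return np.sign(x) * np.clip(np.abs(x) - lambd, 0.0, None)
```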
torch_staintools.functional.optimization.sparse_util module
- torch_staintools.functional.optimization.sparse_util.collate_params(x: Tensor, lr: float | Tensor | None, weight: Tensor, alpha: float | Tensor) Tuple[Tensor, Tensor]
- torch_staintools.functional.optimization.sparse_util.initialize_code(x: Tensor, weight: Tensor, mode: Literal['zero', 'transpose', 'unif', 'ridge'], rng: Generator)
Code initialization in dictionary learning.
Dictionary learning finds a sparse decomposition of the data, X = D * A, where D is the dictionary and A is the code. For ridge initialization, the L2 penalty is set by constants.INIT_RIDGE_L2.
- Parameters:
x – data. B x num_pixel x num_channel
weight – dictionary. B x num_channel x num_stain. Essentially the transposed stain matrix
mode – code initialization method
rng – torch.Generator for random initialization modes
Returns:
- torch_staintools.functional.optimization.sparse_util.initialize_dict(shape: Tuple[int, ...], *, device: device | str, rng: Generator, norm_dim: int, positive_dict: bool)
- torch_staintools.functional.optimization.sparse_util.lipschitz_constant(w: Tensor)
Find the Lipschitz constant used to compute the learning rate in ISTA.
- Parameters:
w – weights w in f(z) = ||Wz - x||^2
Returns:
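Under the convention f(z) = ||Wz - x||^2 stated above, grad f(z) = 2·Wᵀ(Wz - x), so the Lipschitz constant is L = 2·σ_max(W)², and the ISTA step size can be taken as 1/L. A NumPy sketch of that computation (illustrative; the library works on batched torch tensors):

```python
import numpy as np

def lipschitz_constant_sketch(w):
    """Lipschitz constant of grad f for f(z) = ||W z - x||^2.

    grad f(z) = 2 * W.T @ (W z - x), hence L = 2 * sigma_max(W)**2."""
    return 2.0 * np.linalg.norm(w, ord=2) ** 2  # ord=2: largest singular value
```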
- torch_staintools.functional.optimization.sparse_util.ridge(b: Tensor, a: Tensor, alpha: float | None = None)
- torch_staintools.functional.optimization.sparse_util.to_tensor(v: float | Tensor, like: Tensor) Tensor
- torch_staintools.functional.optimization.sparse_util.validate_code(algorithm: Literal['ista', 'cd', 'fista'], init: Literal['zero', 'transpose', 'unif', 'ridge'] | None, z0: Tensor | None, x: Tensor, weight: Tensor, rng)