torch_staintools.functional.optimization package
Submodules
torch_staintools.functional.optimization.dict_learning module
Code directly adapted from https://github.com/rfeinman/pytorch-lasso
- torch_staintools.functional.optimization.dict_learning._ls_batch(od_flatten, stain_matrix)
Use least squares to solve the factorization for the concentration matrix.
Warning
May fail on GPU for individual large inputs (e.g., 1000 x 1000), regardless of batch size. Better suited to many inputs that are small in H and W.
- Parameters:
od_flatten – B x (HW) x num_input_channel
stain_matrix – B x num_stains x num_input_channel
- Returns:
concentration in shape of B x num_stains x (HW)
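The least-squares path can be sketched with `torch.linalg.lstsq`, which accepts batched inputs. This is a minimal sketch based only on the shapes documented above, not the library's implementation; `ls_concentration_batch` is a hypothetical name, and the model is assumed to be OD^T ≈ stain_matrix^T × concentration for each batch item.

```python
import torch


def ls_concentration_batch(od_flatten: torch.Tensor, stain_matrix: torch.Tensor) -> torch.Tensor:
    """Hypothetical least-squares concentration solver for a batch.

    od_flatten:   B x (HW) x C   optical density, one row per pixel
    stain_matrix: B x S x C      one row per stain
    returns:      B x S x (HW)
    """
    # Solve A @ conc = rhs for every batch item, where
    # A = stain_matrix^T (B x C x S) and rhs = od^T (B x C x (HW)).
    A = stain_matrix.transpose(1, 2)
    rhs = od_flatten.transpose(1, 2)
    return torch.linalg.lstsq(A, rhs).solution  # B x S x (HW)
```

On CUDA, `torch.linalg.lstsq` only supports the `'gels'` driver, which assumes the system matrix has full rank; whether that is related to the GPU failures mentioned in the warning is not established here.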
- torch_staintools.functional.optimization.dict_learning.dict_evaluate(x: Tensor, weight: Tensor, alpha: float, rng: Generator, **kwargs)
- torch_staintools.functional.optimization.dict_learning.dict_learning(x, n_components, *, alpha=1.0, constrained=True, persist=False, lambd=0.01, steps=60, device='cpu', progbar=True, rng: Generator | None = None, **solver_kwargs)
- torch_staintools.functional.optimization.dict_learning.get_concentration_batch(od_flatten, stain_matrix, regularizer, algorithm, rng)
- torch_staintools.functional.optimization.dict_learning.get_concentration_one_by_one(od_flatten, stain_matrix, regularizer, algorithm, rng)
- torch_staintools.functional.optimization.dict_learning.get_concentrations(image, stain_matrix, regularizer=0.01, algorithm: Literal['ista', 'cd', 'ls'] = 'ista', rng: Generator | None = None)
Estimate concentration matrix given an image and stain matrix.
Warning
With algorithm = ‘ls’, this may fail on GPU for individual large inputs (e.g., 1000 x 1000), regardless of batch size. Better suited to many inputs that are small in H and W.
- Parameters:
image – batched image(s) in shape of BxCxHxW
stain_matrix – B x num_stains x num_input_channel
regularizer – regularization weight used when the ISTA algorithm is selected
algorithm – method used to compute the concentration, i.e., to solve $\min_C \lVert HE \cdot C - OD \rVert_p$. Supports ‘ista’, ‘cd’, and ‘ls’. ‘ls’ solves the plain least-squares problem $\min_C \lVert HE \cdot C - OD \rVert_F$ (Frobenius norm) and is faster; ‘ista’ and ‘cd’ enforce a sparse (L1-norm) penalty but are slower.
rng – torch.Generator for random initializations
- Returns:
B x num_stains x num_pixel_in_tissue_mask
- Return type:
concentration matrix
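Before any solver runs, the B x C x H x W image has to be brought into the B x (HW) x C layout used by the solvers. The sketch below assumes the input is already in optical density space and skips tissue masking, so its last output dimension is H*W rather than num_pixel_in_tissue_mask; `flatten_od` and `concentrations_ls` are hypothetical helpers that illustrate only the ‘ls’ option.

```python
import torch


def flatten_od(od_image: torch.Tensor) -> torch.Tensor:
    """B x C x H x W  ->  B x (H*W) x C, one row per pixel (row index h*W + w)."""
    b, c, h, w = od_image.shape
    return od_image.permute(0, 2, 3, 1).reshape(b, h * w, c)


def concentrations_ls(od_image: torch.Tensor, stain_matrix: torch.Tensor) -> torch.Tensor:
    """'ls'-style estimate in shape B x num_stains x (H*W); no tissue mask applied."""
    od_flat = flatten_od(od_image)                      # B x (HW) x C
    return torch.linalg.lstsq(stain_matrix.transpose(1, 2),
                              od_flat.transpose(1, 2)).solution
```

Because the flattening is row-major over (H, W), the result can be reshaped back to B x num_stains x H x W to obtain per-stain concentration maps.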
- torch_staintools.functional.optimization.dict_learning.get_concentrations_single(od_flatten, stain_matrix, regularizer=0.01, method: Literal['ista', 'cd', 'ls'] = 'ista', rng: Generator | None = None)
Helper function to estimate the concentration matrix (shape: 2 x (H*W) for two stains) from the flattened optical density of a single image and its stain matrix.
Intended for solvers without batch support: the inputs are individual data points taken from a batch.
- Parameters:
od_flatten – Flattened optical density vectors in shape of (H*W) x C (H and W dimensions flattened).
stain_matrix – the computed stain matrix in shape of num_stains x num_input_channel
regularizer – regularization weight used when the ISTA algorithm is selected
method – which method to compute the concentration: coordinate descent (‘cd’), iterative shrinkage-thresholding algorithm (‘ista’), or least squares (‘ls’)
rng – torch.Generator for random initializations
- Returns:
num_stains x num_pixel_in_tissue_mask
- Return type:
computed concentration
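The one-by-one strategy amounts to looping a non-batched solver over the batch. In this sketch, `concentrations_one_by_one` and `ls_single` are hypothetical: `ls_single` merely stands in for a per-image solver such as `get_concentrations_single`, and returning a list rather than a stacked tensor is an assumption that lets each image carry a different number of pixels.

```python
from typing import Callable, List, Sequence

import torch


def concentrations_one_by_one(
    od_flatten: Sequence[torch.Tensor],    # B items, each (HW_i) x C
    stain_matrix: torch.Tensor,            # B x S x C
    solve_single: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> List[torch.Tensor]:
    # Each call sees a single image: (HW_i) x C and S x C -> S x (HW_i)
    return [solve_single(od, stains) for od, stains in zip(od_flatten, stain_matrix)]


def ls_single(od: torch.Tensor, stains: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a non-batched solver: plain (unbatched) least squares
    return torch.linalg.lstsq(stains.T, od.T).solution
```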
- torch_staintools.functional.optimization.dict_learning.lasso_loss(X: Tensor, Z: Tensor, weight: Tensor, alpha: float = 1.0) → Tensor
Lasso loss definition.
$\sum (X - Z)^2 + \lambda_1 \lvert \mathrm{weight} \rvert + (1 - \alpha) \lVert \mathrm{weight} \rVert_2$
- Parameters:
X – data
Z – code
weight – dictionary
alpha – kept for compatibility purposes; not used.
- Returns:
lasso loss
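As a reference point, the standard lasso objective used in sparse coding is $\frac{1}{2}\lVert X - Z W^\top \rVert_F^2 + \lambda_1 \lVert Z \rVert_1$, here averaged over samples. The sketch below is an assumption and not a transcription of `lasso_loss`; `lasso_objective` is a hypothetical name, and the shapes follow the conventions of `update_dict` (X: n_samples x n_features, Z: n_samples x n_components, weight: n_features x n_components).

```python
import torch


def lasso_objective(X: torch.Tensor, Z: torch.Tensor, weight: torch.Tensor,
                    lambda1: float = 0.01) -> torch.Tensor:
    """0.5 * ||X - Z weight^T||_F^2 + lambda1 * ||Z||_1, averaged over samples."""
    X_hat = Z @ weight.T                        # reconstruction, n_samples x n_features
    data_term = 0.5 * (X - X_hat).pow(2).sum()  # squared Frobenius error
    sparsity = lambda1 * Z.abs().sum()          # L1 penalty on the code
    return (data_term + sparsity) / X.shape[0]
```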
- torch_staintools.functional.optimization.dict_learning.sparse_encode(x: Tensor, weight: Tensor, alpha: float = 0.1, z0=None, algorithm: Literal['ista', 'cd'] = 'ista', init=None, rng: Generator | None = None, **kwargs)
- torch_staintools.functional.optimization.dict_learning.update_dict(dictionary: Tensor, x: Tensor, code: Tensor, positive=True, eps=1e-07, rng: Generator | None = None)
Update the dense dictionary factor in place.
Modified from _update_dict in sklearn.decomposition._dict_learning
- Parameters:
dictionary – Tensor of shape (n_features, n_components). Value of the dictionary at the previous iteration.
x – Tensor of shape (n_samples, n_features). Data matrix.
code – Tensor of shape (n_samples, n_components). Sparse coding of the data against which to optimize the dictionary.
positive – Whether to enforce positivity when finding the dictionary.
eps – Minimum vector norm before an atom is considered “degenerate”.
rng – torch.Generator for initialization of dictionary and code.
Returns:
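scikit-learn's `_update_dict`, which `update_dict` is described as modifying, updates one dictionary atom (column) at a time by block coordinate descent. The sketch below follows that scheme with the column layout documented above; it is not the library's code, and `update_dict_bcd` is a hypothetical name. Two simplifications are assumptions: `eps` is compared against the squared norm of each code column, and degenerate atoms are left unchanged instead of being re-sampled as scikit-learn does.

```python
import torch


def update_dict_bcd(dictionary: torch.Tensor, x: torch.Tensor, code: torch.Tensor,
                    positive: bool = True, eps: float = 1e-7) -> torch.Tensor:
    """One block-coordinate-descent pass over atoms; updates `dictionary` in place.

    dictionary: n_features x n_components
    x:          n_samples x n_features
    code:       n_samples x n_components
    """
    A = code.T @ code          # n_components x n_components
    B = x.T @ code             # n_features x n_components
    for k in range(dictionary.shape[1]):
        if A[k, k] < eps:
            # Degenerate atom (its code column is ~0): left unchanged in this sketch
            continue
        # Exact minimizer of ||x - code @ dictionary.T||_F^2 w.r.t. column k,
        # using the already-updated earlier columns (Gauss-Seidel style)
        dictionary[:, k] += (B[:, k] - dictionary @ A[:, k]) / A[k, k]
        if positive:
            dictionary[:, k].clamp_(min=0)
        # Project onto the unit ball: rescale only if the norm exceeds 1
        dictionary[:, k] /= torch.clamp(dictionary[:, k].norm(), min=1.0)
    return dictionary
```

If the data are generated exactly by a unit-norm, non-negative dictionary, that dictionary is a fixed point of this update.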
- torch_staintools.functional.optimization.dict_learning.update_dict_ridge(x, code, lambd=0.0001)
Update an (unconstrained) dictionary with ridge regression.
This is equivalent to a Newton step with the (L2-regularized) squared error objective: $f(V) = \frac{1}{2N} \lVert Vz - x \rVert_2^2 + \frac{\lambda}{2} \lVert V \rVert_2^2$, where $\lambda$ is lambd.
- Parameters:
x – a batch of observations with shape (n_samples, n_features)
code – a batch of code vectors with shape (n_samples, n_components)
lambd – weight decay parameter
Returns:
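Because the objective above is quadratic in V, a single Newton step lands on its exact minimizer: setting the gradient to zero gives $V (Z^\top Z + N \lambda I) = X^\top Z$, with Z the code matrix and X the observations stacked as rows. This sketch assumes $\lVert V \rVert_2$ means the element-wise (Frobenius) norm and returns V with shape (n_features, n_components); `update_dict_ridge_sketch` is a hypothetical name.

```python
import torch


def update_dict_ridge_sketch(x: torch.Tensor, code: torch.Tensor,
                             lambd: float = 1e-4) -> torch.Tensor:
    """Closed-form minimizer of
    f(V) = 1/(2N) * ||code @ V.T - x||_F^2 + lambd/2 * ||V||_F^2.

    x: N x n_features, code: N x n_components -> V: n_features x n_components
    """
    n, n_components = code.shape
    # Zero gradient: V (code^T code + N*lambd*I) = x^T code
    gram = code.T @ code + n * lambd * torch.eye(n_components, dtype=code.dtype,
                                                 device=code.device)
    # gram is symmetric, so solve gram @ V^T = code^T @ x and transpose
    return torch.linalg.solve(gram, code.T @ x).T
```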
torch_staintools.functional.optimization.solver module
Code directly adapted from https://github.com/rfeinman/pytorch-lasso
- torch_staintools.functional.optimization.solver._lipschitz_constant(W)
Find the Lipschitz constant used to compute the learning rate in ISTA.
- Parameters:
W – weight matrix W in $f(z) = \lVert Wz - x \rVert^2$
Returns:
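For $f(z) = \frac{1}{2}\lVert Wz - x \rVert^2$, the gradient $W^\top(Wz - x)$ is Lipschitz with constant equal to the largest eigenvalue of $W^\top W$ (the squared spectral norm of W), and ISTA then uses step size 1/L. The docstring writes f without the 1/2, which would double the constant; which scaling `_lipschitz_constant` uses is not stated, so this sketch (hypothetical name `lipschitz_constant_sketch`) assumes the 1/2 convention.

```python
import torch


def lipschitz_constant_sketch(W: torch.Tensor) -> torch.Tensor:
    """Largest eigenvalue of W^T W, i.e. sigma_max(W)^2.

    Lipschitz constant of grad f for f(z) = 0.5 * ||Wz - x||^2.
    """
    # eigvalsh returns the eigenvalues of a symmetric matrix in ascending order
    return torch.linalg.eigvalsh(W.T @ W)[-1]
```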
- torch_staintools.functional.optimization.solver.coord_descent(x, W, z0=None, alpha=1.0, lambda1=0.01, maxiter=1000, tol=1e-06, verbose=False)
Modified coordinate descent solver.
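For orientation, plain cyclic coordinate descent for the lasso problem $\min_z \frac{1}{2}\lVert x - z W^\top \rVert_F^2 + \lambda_1 \lVert z \rVert_1$ updates one code coordinate at a time with a soft-thresholding step. This is a generic sketch, not the library's modified version: it ignores `z0`, `alpha`, and `verbose`, assumes samples as rows (x: n_samples x n_features, W: n_features x n_components), and assumes no all-zero atoms. `coord_descent_sketch` and `soft_threshold` are hypothetical names.

```python
import torch


def soft_threshold(v: torch.Tensor, thresh) -> torch.Tensor:
    return torch.sign(v) * torch.clamp(v.abs() - thresh, min=0.0)


def coord_descent_sketch(x: torch.Tensor, W: torch.Tensor, lambda1: float = 0.01,
                         maxiter: int = 1000, tol: float = 1e-6) -> torch.Tensor:
    """Cyclic coordinate descent for min_z 0.5*||x - z W^T||_F^2 + lambda1*||z||_1."""
    z = torch.zeros(x.shape[0], W.shape[1], dtype=x.dtype, device=x.device)
    col_norms = (W * W).sum(dim=0)          # ||W_j||^2 for each atom
    residual = x - z @ W.T                  # kept up to date after every coordinate step
    for _ in range(maxiter):
        z_old = z.clone()
        for j in range(W.shape[1]):
            # Correlation of atom j with the residual, adding back its own contribution
            rho = residual @ W[:, j] + col_norms[j] * z[:, j]
            z_new = soft_threshold(rho, lambda1) / col_norms[j]
            residual += torch.outer(z[:, j] - z_new, W[:, j])
            z[:, j] = z_new
        if (z - z_old).abs().max() < tol:
            break
    return z
```

With orthonormal atoms, each coordinate update is exact and the solution reduces to soft-thresholding $x W$ by $\lambda_1$.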
- torch_staintools.functional.optimization.solver.ista(x, z0, weight, alpha=1.0, fast=True, lr='auto', maxiter=50, tol=1e-05, lambda1=0.01, verbose=False, rng: Generator | None = None)
ISTA (iterative shrinkage-thresholding algorithm) solver.
- Parameters:
x – data
z0 – code, or the initialization mode of the code.
weight – dictionary
alpha – eps term for code initialization
fast – whether to use FISTA (fast ISTA) instead of ISTA
lr – learning rate/step size. If ‘auto’, it is set from the Lipschitz constant of $f(z) = \lVert Wz - x \rVert^2$.
maxiter – maximum number of iterations if the solver does not converge.
tol – tolerance of the convergence test.
lambda1 – weight (lambda) of the sparse (L1) term.
verbose – whether to print the progress
rng – torch.Generator for random initialization
Returns:
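The ISTA parameters map onto a proximal-gradient loop: a gradient step on the squared error with step size `lr` = 1/L, followed by soft-thresholding with threshold lr·lambda1; `fast=True` adds the FISTA momentum step. This sketch always starts from a zero code, uses the 1/2 scaling of the squared error, and stops when the largest code change falls below `tol`; all three are assumptions. `ista_sketch` is hypothetical and omits `z0`, `alpha`, `verbose`, and `rng`.

```python
import torch


def soft_threshold(v: torch.Tensor, thresh) -> torch.Tensor:
    return torch.sign(v) * torch.clamp(v.abs() - thresh, min=0.0)


def ista_sketch(x: torch.Tensor, weight: torch.Tensor, lambda1: float = 0.01,
                fast: bool = True, maxiter: int = 50, tol: float = 1e-5) -> torch.Tensor:
    """(F)ISTA for min_z 0.5*||x - z weight^T||_F^2 + lambda1*||z||_1.

    x: n_samples x n_features, weight: n_features x n_components
    """
    L = torch.linalg.eigvalsh(weight.T @ weight)[-1]   # Lipschitz constant of the gradient
    lr = 1.0 / L
    z = torch.zeros(x.shape[0], weight.shape[1], dtype=x.dtype, device=x.device)
    y, t = z, 1.0                                     # extrapolated point, momentum scalar
    for _ in range(maxiter):
        grad = (y @ weight.T - x) @ weight            # gradient of the smooth term at y
        z_next = soft_threshold(y - lr * grad, lr * lambda1)
        if fast:
            t_next = (1 + (1 + 4 * t * t) ** 0.5) / 2
            y = z_next + ((t - 1) / t_next) * (z_next - z)
            t = t_next
        else:
            y = z_next
        converged = (z_next - z).abs().max() < tol
        z = z_next
        if converged:
            break
    return z
```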
torch_staintools.functional.optimization.sparse_util module
- torch_staintools.functional.optimization.sparse_util.initialize_code(x, weight, alpha, mode, rng: Generator)
Code initialization in dictionary learning.
Dictionary learning finds a sparse decomposition of the data, $X = D A$, where D is the dictionary and A is the code.
- Parameters:
x – data
weight – dictionary
alpha – small eps term on diagonal for ridge initialization
mode – code initialization method
rng – torch.Generator for random initialization modes
Returns:
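One initialization mentioned above is ridge initialization, where `alpha` is added to the diagonal of the normal equations. A minimal sketch, assuming the ridge problem $\min_z \lVert x - z W^\top \rVert^2 + \alpha \lVert z \rVert^2$ with samples as rows; the names of the available modes and the exact arguments of the module's `ridge(b, A, alpha)` helper are not documented here, so `ridge_code_init` is hypothetical.

```python
import torch


def ridge_code_init(x: torch.Tensor, weight: torch.Tensor, alpha: float = 1e-4) -> torch.Tensor:
    """argmin_z ||x - z @ weight.T||^2 + alpha * ||z||^2.

    x: n_samples x n_features, weight (dictionary): n_features x n_components
    returns z: n_samples x n_components
    """
    n_components = weight.shape[1]
    # Normal equations (W^T W + alpha I) z^T = W^T x^T; alpha keeps the system well-posed
    gram = weight.T @ weight + alpha * torch.eye(n_components, dtype=weight.dtype,
                                                 device=weight.device)
    return torch.linalg.solve(gram, weight.T @ x.T).T
```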
- torch_staintools.functional.optimization.sparse_util.ridge(b: Tensor, A: Tensor, alpha: float = 0.0001)