torch_staintools.functional.optimization package
Submodules
torch_staintools.functional.optimization.dict_learning module
Code directly adapted from https://github.com/rfeinman/pytorch-lasso
- torch_staintools.functional.optimization.dict_learning._ls_batch(od_flatten, stain_matrix)
- Use least squares to solve the factorization for the concentration. - Warning - May fail on GPU for an individual large input (e.g., 1000 x 1000), regardless of batch size. Better suited to multiple small inputs in terms of H and W. - Parameters:
- od_flatten – B x (HW) x num_input_channel
- stain_matrix – B x num_stains x num_input_channel 
 
- Returns:
- concentration in shape of B x num_stains x (HW)
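A minimal NumPy sketch of the least-squares step, assuming each flattened OD matrix is modeled as od_flatten[b] ≈ concentration[b]ᵀ · stain_matrix[b]. The helper name `ls_concentration_batch` is hypothetical and this is not the library's implementation (which works on torch tensors); it solves the normal equations for all batch items at once.

```python
import numpy as np


def ls_concentration_batch(od_flatten: np.ndarray, stain_matrix: np.ndarray) -> np.ndarray:
    """Least-squares concentrations for a batch (hypothetical helper, NumPy sketch).

    od_flatten:   B x (HW) x C optical densities
    stain_matrix: B x S x C stain vectors (one row per stain)
    returns:      B x S x (HW) concentrations
    """
    # Model: od_flatten[b] ~= concentration[b].T @ stain_matrix[b]
    # Normal equations: (S S^T) conc = S OD^T, solved for every batch item at once.
    gram = stain_matrix @ stain_matrix.transpose(0, 2, 1)   # B x S x S
    rhs = stain_matrix @ od_flatten.transpose(0, 2, 1)      # B x S x (HW)
    return np.linalg.solve(gram, rhs)                       # B x S x (HW)
```

Forming the normal equations squares the condition number of the stain matrix; a dedicated least-squares routine such as `numpy.linalg.lstsq` or `torch.linalg.lstsq` is numerically safer when stain vectors are nearly collinear.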
 
- torch_staintools.functional.optimization.dict_learning.dict_evaluate(x: Tensor, weight: Tensor, alpha: float, rng: Generator, **kwargs)
- torch_staintools.functional.optimization.dict_learning.dict_learning(x, n_components, *, alpha=1.0, constrained=True, persist=False, lambd=0.01, steps=60, device='cpu', progbar=True, rng: Generator | None = None, **solver_kwargs)
- torch_staintools.functional.optimization.dict_learning.get_concentration_batch(od_flatten, stain_matrix, regularizer, algorithm, rng)
- torch_staintools.functional.optimization.dict_learning.get_concentration_one_by_one(od_flatten, stain_matrix, regularizer, algorithm, rng)
- torch_staintools.functional.optimization.dict_learning.get_concentrations(image, stain_matrix, regularizer=0.01, algorithm: Literal['ista', 'cd', 'ls'] = 'ista', rng: Generator | None = None)
- Estimate the concentration matrix given an image and a stain matrix. - Warning - algorithm = ‘ls’ may fail on GPU for an individual large input (e.g., 1000 x 1000), regardless of batch size. It is better suited to multiple small inputs in terms of H and W. - Parameters:
- image – batched image(s) in shape of BxCxHxW 
- stain_matrix – B x num_stain x input channel 
- regularizer – regularization term if ISTA algorithm is used 
- algorithm – which method to use to compute the concentration C by solving min ||HE x C - OD||_p. Supports ‘ista’, ‘cd’, and ‘ls’. ‘ls’ solves the plain least-squares problem min ||HE x C - OD||_F (Frobenius norm) and is faster; ‘ista’ and ‘cd’ enforce a sparsity penalty (L1 norm) but are slower.
- rng – torch.Generator for random initializations 
 
- Returns:
- concentration matrix in shape of B x num_stains x num_pixel_in_tissue_mask
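To make the `algorithm` options concrete, here is a hedged sketch of the data layout and the objectives, assuming the input has already been converted to optical density (OD) and that the flattened layout matches the B x (HW) x C shape used by `_ls_batch`. Both helper names are hypothetical, and the 0.5 scaling on the data term is an assumption.

```python
import numpy as np


def flatten_od(od_image: np.ndarray) -> np.ndarray:
    """B x C x H x W -> B x (H*W) x C (hypothetical helper)."""
    b, c, h, w = od_image.shape
    return od_image.reshape(b, c, h * w).transpose(0, 2, 1)


def concentration_objective(conc, stain_matrix, od_flatten, regularizer=0.0):
    """Per-batch objective 0.5 * ||C^T HE - OD||_F^2 + regularizer * ||C||_1.

    regularizer = 0 gives the plain least-squares ('ls') objective;
    regularizer > 0 adds the L1 sparsity penalty of the sparse solvers.
    conc: B x S x (HW), stain_matrix: B x S x C, od_flatten: B x (HW) x C
    """
    residual = conc.transpose(0, 2, 1) @ stain_matrix - od_flatten  # B x (HW) x C
    fit = 0.5 * np.sum(residual ** 2, axis=(1, 2))
    penalty = regularizer * np.sum(np.abs(conc), axis=(1, 2))
    return fit + penalty
```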
 
- torch_staintools.functional.optimization.dict_learning.get_concentrations_single(od_flatten, stain_matrix, regularizer=0.01, method: Literal['ista', 'cd', 'ls'] = 'ista', rng: Generator | None = None)
- Helper function to estimate the concentration matrix (shape: 2 x (H*W) for two stains) given an image and a stain matrix. - For solvers without batch support: inputs are individual data points from a batch. - Parameters:
- od_flatten – Flattened optical density vectors in shape of (H*W) x C (H and W dimensions flattened). 
- stain_matrix – the computed stain matrices in shape of num_stain x input channel 
- regularizer – regularization term if ISTA algorithm is used 
- method – which method to use to compute the concentration: coordinate descent (‘cd’), the iterative shrinkage-thresholding algorithm (‘ista’), or least squares (‘ls’)
- rng – torch.Generator for random initializations 
 
- Returns:
- computed concentration in shape of num_stains x num_pixel_in_tissue_mask
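For solvers without batch support, the batch can be processed one item at a time and the results stacked, which is presumably the role of `get_concentration_one_by_one` together with `get_concentrations_single`. The sketch below shows that pattern with a least-squares single-sample solver; the function names are hypothetical and NumPy stands in for torch.

```python
import numpy as np


def concentrations_single_ls(od_flatten, stain_matrix):
    """(H*W) x C OD and num_stains x C stains -> num_stains x (H*W) (hypothetical)."""
    # Solve stain_matrix.T @ conc ~= od_flatten.T in the least-squares sense.
    conc, *_ = np.linalg.lstsq(stain_matrix.T, od_flatten.T, rcond=None)
    return conc


def concentrations_one_by_one(od_flatten, stain_matrix, solver=concentrations_single_ls):
    """Apply a single-sample solver to each batch item and stack the results."""
    return np.stack([solver(od, stain) for od, stain in zip(od_flatten, stain_matrix)])
```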
 
- torch_staintools.functional.optimization.dict_learning.lasso_loss(X: Tensor, Z: Tensor, weight: Tensor, alpha: float = 1.0) → Tensor
- Lasso loss definition. - sum((X - Z)^2) + lambda1 * |weight| + (1 - alpha) * |weight|_2 - Parameters:
- X – 
- Z – 
- weight – 
- alpha – kept for compatibility purposes; not used.
 
- Returns:
- lasso loss 
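The formula above is terse (it penalizes `weight` rather than the code, and the roles of `X` and `Z` are not documented). For comparison, here is the textbook lasso objective under the convention assumed here: the data X is reconstructed as Z · weightᵀ, with `weight` the n_features x n_components dictionary and Z the code. It is a reference point, not a transcription of `lasso_loss`; the name `lasso_objective` is hypothetical.

```python
import numpy as np


def lasso_objective(x, z, weight, lambda1=0.01):
    """Textbook lasso objective (sketch; may differ from the library's formula).

    x:      n_samples x n_features   data
    z:      n_samples x n_components code
    weight: n_features x n_components dictionary
    """
    x_hat = z @ weight.T                      # reconstruction of the data
    fit = 0.5 * np.sum((x - x_hat) ** 2)      # squared reconstruction error
    sparsity = lambda1 * np.sum(np.abs(z))    # L1 penalty on the code
    return fit + sparsity
```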
 
- torch_staintools.functional.optimization.dict_learning.sparse_encode(x: Tensor, weight: Tensor, alpha: float = 0.1, z0=None, algorithm: Literal['ista', 'cd'] = 'ista', init=None, rng: Generator | None = None, **kwargs)
- torch_staintools.functional.optimization.dict_learning.update_dict(dictionary: Tensor, x: Tensor, code: Tensor, positive=True, eps=1e-07, rng: Generator | None = None)
- Update the dense dictionary factor in place. - Modified from _update_dict in sklearn.decomposition._dict_learning - Parameters:
- dictionary – Tensor of shape (n_features, n_components). Value of the dictionary at the previous iteration.
- x – Tensor of shape (n_samples, n_features). Data matrix.
- code – Tensor of shape (n_samples, n_components). Sparse coding of the data against which to optimize the dictionary.
- positive – Whether to enforce positivity when finding the dictionary. 
- eps – Minimum vector norm below which a vector is considered “degenerate”.
- rng – torch.Generator for initialization of dictionary and code. 
 
 - Returns: 
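`update_dict` is described as modified from scikit-learn's `_update_dict`, which updates one dictionary atom at a time with a closed-form step and then projects it back onto the constraint set. Below is a NumPy sketch of that rule, assuming the shapes listed above (dictionary n_features x n_components, x n_samples x n_features, code n_samples x n_components). The name `update_dict_sketch` is hypothetical, applying `eps` to the atom's squared code usage is an assumption, and unlike scikit-learn it does not add noise or reset the code when re-seeding an unused atom.

```python
import numpy as np


def update_dict_sketch(dictionary, x, code, positive=True, eps=1e-7, rng=None):
    """Block-coordinate dictionary update (NumPy sketch of the sklearn-style rule).

    dictionary: n_features x n_components float array (updated in place, returned)
    x:          n_samples x n_features
    code:       n_samples x n_components
    """
    rng = np.random.default_rng() if rng is None else rng
    gram = code.T @ code      # n_components x n_components
    cross = x.T @ code        # n_features x n_components
    for k in range(dictionary.shape[1]):
        if gram[k, k] > eps:
            # Exact minimizer of the squared error in atom k, other atoms fixed.
            dictionary[:, k] += (cross[:, k] - dictionary @ gram[:, k]) / gram[k, k]
        else:
            # Atom k is (almost) unused: re-seed it from a random data sample.
            dictionary[:, k] = x[rng.integers(x.shape[0])]
        if positive:
            np.clip(dictionary[:, k], 0, None, out=dictionary[:, k])
        # Project onto the unit ball: ||atom_k|| <= 1.
        dictionary[:, k] /= max(np.linalg.norm(dictionary[:, k]), 1.0)
    return dictionary
```

Because the loss is an isotropic quadratic in each atom, clipping to non-negative values and then rescaling to norm at most 1 gives the exact constrained minimizer for that atom, so a sweep starting from a feasible dictionary cannot increase the reconstruction error.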
- torch_staintools.functional.optimization.dict_learning.update_dict_ridge(x, code, lambd=0.0001)
- Update an (unconstrained) dictionary with ridge regression - This is equivalent to a Newton step with the (L2-regularized) squared error objective: f(V) = (1/2N) * ||Vz - x||_2^2 + (lambd/2) * ||V||_2^2 - Parameters:
- x – a batch of observations with shape (n_samples, n_features) 
- code – a batch of code vectors with shape (n_samples, n_components)
- lambd – weight decay parameter 
 
 - Returns: 
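Because f(V) is quadratic, the Newton step reaches its minimizer in closed form: setting the gradient (1/N)(V ZᵀZ − XᵀZ) + lambd · V to zero gives V = XᵀZ (ZᵀZ + N · lambd · I)⁻¹, where the rows of X and Z are the observations and code vectors. A NumPy sketch under that reading (the name `update_dict_ridge_sketch` is hypothetical):

```python
import numpy as np


def update_dict_ridge_sketch(x, code, lambd=1e-4):
    """Closed-form minimizer of (1/2N)||V z - x||^2 + (lambd/2)||V||^2 (sketch).

    x:    n_samples x n_features
    code: n_samples x n_components
    returns V: n_features x n_components
    """
    n_samples, n_components = code.shape
    # Zero gradient: V (Z^T Z + N * lambd * I) = X^T Z, with a symmetric left factor.
    lhs = code.T @ code + n_samples * lambd * np.eye(n_components)
    rhs = code.T @ x                          # n_components x n_features
    return np.linalg.solve(lhs, rhs).T        # solve for V^T, then transpose
```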
torch_staintools.functional.optimization.solver module
Code directly adapted from https://github.com/rfeinman/pytorch-lasso
- torch_staintools.functional.optimization.solver._lipschitz_constant(W)
- Find the Lipschitz constant used to compute the learning rate (step size) in ISTA. - Parameters:
- W – weight matrix W in f(z) = ||Wz - x||^2
 - Returns: 
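For the smooth term (1/2)||Wz − x||², the gradient Wᵀ(Wz − x) is Lipschitz with constant L = λ_max(WᵀW) = ||W||₂², and ISTA takes steps of size 1/L. Without the 1/2 factor, as written above, the constant doubles; which convention the library uses is an assumption here. A one-line NumPy sketch (hypothetical name):

```python
import numpy as np


def lipschitz_constant(W):
    """Largest eigenvalue of W^T W, i.e. the squared spectral norm of W (sketch)."""
    # eigvalsh returns the eigenvalues of a symmetric matrix in ascending order.
    return np.linalg.eigvalsh(W.T @ W)[-1]
```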
- torch_staintools.functional.optimization.solver.coord_descent(x, W, z0=None, alpha=1.0, lambda1=0.01, maxiter=1000, tol=1e-06, verbose=False)
- Modified version of the coordinate descent solver.
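The docstring does not spell out the algorithm. Assuming `coord_descent` targets the same lasso problem as ISTA, min_z (1/2)||Wz − x||² + lambda1 · ||z||₁, cyclic coordinate descent updates one entry at a time in closed form with a soft-threshold. The sketch below handles a single sample, starts from zero, and omits `z0`, `alpha`, and `verbose`; the names are hypothetical.

```python
import numpy as np


def soft_threshold(v, thresh):
    """Elementwise shrinkage toward zero by `thresh`."""
    return np.sign(v) * np.maximum(np.abs(v) - thresh, 0.0)


def coord_descent_sketch(x, W, lambda1=0.01, maxiter=1000, tol=1e-6):
    """Cyclic coordinate descent for min_z 0.5*||W z - x||^2 + lambda1*||z||_1.

    x: n_features, W: n_features x n_components; returns z: n_components
    """
    z = np.zeros(W.shape[1])
    col_sq = np.sum(W ** 2, axis=0)     # ||W[:, j]||^2 for every column
    residual = x - W @ z
    for _ in range(maxiter):
        z_old = z.copy()
        for j in range(W.shape[1]):
            # Add coordinate j back into the residual, then re-solve for it exactly.
            residual += W[:, j] * z[j]
            rho = W[:, j] @ residual
            z[j] = soft_threshold(rho, lambda1) / col_sq[j]
            residual -= W[:, j] * z[j]
        if np.max(np.abs(z - z_old)) < tol:
            break
    return z
```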
- torch_staintools.functional.optimization.solver.ista(x, z0, weight, alpha=1.0, fast=True, lr='auto', maxiter=50, tol=1e-05, lambda1=0.01, verbose=False, rng: Generator | None = None)
- ISTA solver - Parameters:
- x – data
- z0 – initial code, or the initialization mode of the code.
- weight – dictionary
- alpha – eps term for code initialization 
- fast – whether to use FISTA (fast-ista) instead of ISTA 
- lr – learning rate (step size). If ‘auto’, it is derived from the Lipschitz constant of the gradient of f(z) = ||Wz - x||^2.
- maxiter – maximum number of iterations if the solver does not converge.
- tol – tolerance for the convergence test.
- lambda1 – weight (lambda) of the sparse L1 penalty term.
- verbose – whether to print the progress 
- rng – torch.Generator for random initialization 
 
 - Returns: 
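A NumPy sketch of the (F)ISTA iteration described by these parameters, assuming the problem min_z (1/2)||z Wᵀ − x||² + lambda1 · ||z||₁ with samples as rows, a zero initial code in place of `z0`, and the ‘auto’ step size 1/L from the Lipschitz constant. The name `ista_sketch` is hypothetical; the momentum update is the standard FISTA rule.

```python
import numpy as np


def soft_threshold(v, thresh):
    """Elementwise shrinkage toward zero by `thresh`."""
    return np.sign(v) * np.maximum(np.abs(v) - thresh, 0.0)


def ista_sketch(x, weight, lambda1=0.01, fast=True, maxiter=50, tol=1e-5):
    """(F)ISTA for min_z 0.5*||z W^T - x||^2 + lambda1*||z||_1, batched over rows.

    x: n_samples x n_features, weight: n_features x n_components
    returns z: n_samples x n_components
    """
    # Step size 1/L, with L the largest eigenvalue of W^T W.
    lr = 1.0 / np.linalg.eigvalsh(weight.T @ weight)[-1]
    z = np.zeros((x.shape[0], weight.shape[1]))
    y, t = z.copy(), 1.0
    for _ in range(maxiter):
        grad = (y @ weight.T - x) @ weight          # gradient of the smooth part at y
        z_new = soft_threshold(y - lr * grad, lr * lambda1)
        if fast:
            # FISTA momentum: extrapolate along the last step.
            t_new = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
            y = z_new + ((t - 1.0) / t_new) * (z_new - z)
            t = t_new
        else:
            y = z_new
        converged = np.abs(z_new - z).max() < tol
        z = z_new
        if converged:
            break
    return z
```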
torch_staintools.functional.optimization.sparse_util module
- torch_staintools.functional.optimization.sparse_util.initialize_code(x, weight, alpha, mode, rng: Generator)
- Code initialization in dictionary learning. - Dictionary learning finds a sparse decomposition of the data X = D * A, where D is the dictionary and A is the code. - Parameters:
- x – data 
- weight – dictionary 
- alpha – small eps term added to the diagonal for ridge initialization
- mode – code initialization method 
- rng – torch.Generator for random initialization modes 
 
 - Returns: 
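A hedged sketch of two plausible initialization modes for the code, with samples stored as rows so that X ≈ A · Dᵀ (the transpose of X = D * A above). The mode strings ‘zero’ and ‘ridge’ are hypothetical and may not match the values `initialize_code` accepts; `alpha` plays the role of the small diagonal term described above.

```python
import numpy as np


def initialize_code_sketch(x, weight, alpha=1e-4, mode="ridge"):
    """Initial code A for X ~= A @ D^T (hypothetical modes 'zero' and 'ridge').

    x: n_samples x n_features, weight (dictionary D): n_features x n_components
    """
    n_components = weight.shape[1]
    if mode == "zero":
        return np.zeros((x.shape[0], n_components))
    if mode == "ridge":
        # Ridge solution: (D^T D + alpha * I) A^T = D^T X^T
        gram = weight.T @ weight + alpha * np.eye(n_components)
        return np.linalg.solve(gram, weight.T @ x.T).T
    raise ValueError(f"unknown mode: {mode!r}")
```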
- torch_staintools.functional.optimization.sparse_util.ridge(b: Tensor, A: Tensor, alpha: float = 0.0001)
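`ridge` has no docstring. Assuming, from the argument names, that it solves the standard ridge problem min_x ||Ax − b||² + alpha · ||x||², the solution satisfies (AᵀA + alpha · I) x = Aᵀb. A NumPy sketch with a hypothetical name:

```python
import numpy as np


def ridge_sketch(b, A, alpha=1e-4):
    """Solve min_x ||A x - b||^2 + alpha * ||x||^2 via its normal equations (sketch).

    A: m x n, b: m or m x k; returns x: n or n x k
    """
    # Regularized Gram matrix: symmetric positive definite for alpha > 0.
    gram = A.T @ A + alpha * np.eye(A.shape[1])
    return np.linalg.solve(gram, A.T @ b)
```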