torch_staintools.functional.optimization package

Submodules

torch_staintools.functional.optimization.dict_learning module

Code directly adapted from https://github.com/rfeinman/pytorch-lasso.

torch_staintools.functional.optimization.dict_learning._ls_batch(od_flatten, stain_matrix)

Use least squares to solve the factorization for the concentration.

Warning

May fail on GPU for large individual inputs (e.g., 1000 x 1000), regardless of batch size. Better suited to batches of inputs that are small in H and W.

Parameters:
  • od_flatten – B x (H*W) x num_input_channel

  • stain_matrix – B x num_stains x num_input_channel

Returns:

concentration, B x num_stains x (H*W)
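
As a hedged sketch, the factorization can be phrased as a batched least-squares solve; the shapes and the use of torch.linalg.lstsq here are assumptions for illustration, not necessarily the module's exact implementation:

    import torch

    # Hypothetical sizes: B images, S stains, C input channels, N = H*W pixels.
    B, S, C, N = 2, 2, 3, 64
    od_flatten = torch.rand(B, N, C)      # optical density, B x (H*W) x C
    stain_matrix = torch.rand(B, S, C)    # B x num_stains x C

    # Solve min ||stain_matrix^T @ conc - od_flatten^T||_F per batch element.
    # On CUDA, torch.linalg.lstsq assumes full-rank inputs, which may relate
    # to the GPU failure mode noted in the warning above.
    conc = torch.linalg.lstsq(
        stain_matrix.transpose(-2, -1),   # B x C x num_stains
        od_flatten.transpose(-2, -1),     # B x C x (H*W)
    ).solution                            # B x num_stains x (H*W)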

torch_staintools.functional.optimization.dict_learning.dict_evaluate(x: Tensor, weight: Tensor, alpha: float, rng: Generator, **kwargs)
torch_staintools.functional.optimization.dict_learning.dict_learning(x, n_components, *, alpha=1.0, constrained=True, persist=False, lambd=0.01, steps=60, device='cpu', progbar=True, rng: Generator | None = None, **solver_kwargs)
torch_staintools.functional.optimization.dict_learning.get_concentration_batch(od_flatten, stain_matrix, regularizer, algorithm, rng)
torch_staintools.functional.optimization.dict_learning.get_concentration_one_by_one(od_flatten, stain_matrix, regularizer, algorithm, rng)
torch_staintools.functional.optimization.dict_learning.get_concentrations(image, stain_matrix, regularizer=0.01, algorithm: Literal['ista', 'cd', 'ls'] = 'ista', rng: Generator | None = None)

Estimate concentration matrix given an image and stain matrix.

Warning

With algorithm='ls', may fail on GPU for large individual inputs (e.g., 1000 x 1000), regardless of batch size. Better suited to batches of inputs that are small in H and W.

Parameters:
  • image – batched image(s) in shape of BxCxHxW

  • stain_matrix – B x num_stains x num_input_channel

  • regularizer – regularization term if the ISTA algorithm is used

  • algorithm – method used to compute the concentration, solving min ||HE x C - OD||_p; supports ‘ista’, ‘cd’, and ‘ls’. ‘ls’ simply solves the least-squares factorization min ||HE x C - OD||_F (Frobenius norm) and is faster; ‘ista’ and ‘cd’ enforce a sparsity penalty (L1 norm) but are slower.

  • rng – torch.Generator for random initializations

Returns:

B x num_stains x num_pixel_in_tissue_mask

Return type:

concentration matrix
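
A minimal usage sketch; the random stain matrix stands in for one produced by a stain extractor (e.g., Macenko or Vahadane) and the values are illustrative only:

    import torch
    from torch_staintools.functional.optimization.dict_learning import get_concentrations

    rng = torch.Generator().manual_seed(0)
    image = torch.rand(1, 3, 64, 64)      # B x C x H x W
    stain_matrix = torch.rand(1, 2, 3)    # B x num_stains x num_input_channel

    # Result: B x num_stains x num_pixel_in_tissue_mask (per the Returns above).
    concentration = get_concentrations(
        image, stain_matrix, regularizer=0.01, algorithm='ista', rng=rng)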

torch_staintools.functional.optimization.dict_learning.get_concentrations_single(od_flatten, stain_matrix, regularizer=0.01, method: Literal['ista', 'cd', 'ls'] = 'ista', rng: Generator | None = None)

Helper function to estimate the concentration matrix, in shape 2 x (H*W), given a single image’s flattened optical density and stain matrix.

For solvers without batch support; inputs are individual data points from a batch.

Parameters:
  • od_flatten – Flattened optical density vectors in shape of (H*W) x C (H and W dimensions flattened).

  • stain_matrix – the computed stain matrix in shape of num_stains x num_input_channel

  • regularizer – regularization term if the ISTA algorithm is used

  • method – which method to compute the concentration: coordinate descent (‘cd’), iterative shrinkage-thresholding algorithm (‘ista’), or least squares (‘ls’)

  • rng – torch.Generator for random initializations

Returns:

num_stains x num_pixel_in_tissue_mask

Return type:

computed concentration
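
A usage sketch for a single, unbatched input; shapes follow the parameter descriptions above:

    import torch
    from torch_staintools.functional.optimization.dict_learning import get_concentrations_single

    od_flatten = torch.rand(64 * 64, 3)   # (H*W) x C flattened optical density
    stain_matrix = torch.rand(2, 3)       # num_stains x num_input_channel

    conc = get_concentrations_single(
        od_flatten, stain_matrix, regularizer=0.01, method='ista')
    # conc: num_stains x (H*W)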

torch_staintools.functional.optimization.dict_learning.lasso_loss(X: Tensor, Z: Tensor, weight: Tensor, alpha: float = 1.0) → Tensor

Lasso loss definition.

sum((X - Z)^2) + lambda1 * ||weight||_1 + (1 - alpha) * ||weight||_2^2

Parameters:
  • X – data.

  • Z – code.

  • weight – dictionary.

  • alpha – for compatibility purposes; not used.

Returns:

lasso loss
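
For reference, a textbook lasso objective is sketched below; lasso_objective is a hypothetical helper assuming the data X is reconstructed as Z @ weight.T, not necessarily the exact form this function implements:

    import torch

    def lasso_objective(X: torch.Tensor, Z: torch.Tensor,
                        weight: torch.Tensor, lambda1: float = 0.01) -> torch.Tensor:
        # Squared reconstruction error of data X from code Z and dictionary
        # `weight`, plus an L1 sparsity penalty on the code (sketch only).
        recon = Z @ weight.T
        return (X - recon).pow(2).sum() + lambda1 * Z.abs().sum()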

torch_staintools.functional.optimization.dict_learning.sparse_encode(x: Tensor, weight: Tensor, alpha: float = 0.1, z0=None, algorithm: Literal['ista', 'cd'] = 'ista', init=None, rng: Generator | None = None, **kwargs)
torch_staintools.functional.optimization.dict_learning.update_dict(dictionary: Tensor, x: Tensor, code: Tensor, positive=True, eps=1e-07, rng: Generator | None = None)

Update the dense dictionary factor in place.

Modified from _update_dict in sklearn.decomposition._dict_learning

Parameters:
  • dictionary – Tensor of shape (n_features, n_components). Value of the dictionary at the previous iteration.

  • x – Tensor of shape (n_samples, n_features). Data matrix against which to optimize the dictionary.

  • code – Tensor of shape (n_samples, n_components). Sparse coding of the data against which to optimize the dictionary.

  • positive – Whether to enforce positivity when finding the dictionary.

  • eps – Minimum vector norm below which a dictionary atom is considered degenerate.

  • rng – torch.Generator for initialization of dictionary and code.

Returns:
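
A usage sketch with hypothetical shapes mirroring the parameter descriptions above:

    import torch
    from torch_staintools.functional.optimization.dict_learning import update_dict

    x = torch.rand(100, 8)                # data (n_samples, n_features)
    code = torch.rand(100, 4)             # codes (n_samples, n_components)
    dictionary = torch.rand(8, 4)         # (n_features, n_components)

    update_dict(dictionary, x, code, positive=True)  # updated in place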

torch_staintools.functional.optimization.dict_learning.update_dict_ridge(x, code, lambd=0.0001)

Update an (unconstrained) dictionary with ridge regression.

This is equivalent to a Newton step on the (L2-regularized) squared-error objective: f(V) = (1/(2N)) * ||V z - x||_2^2 + (lambd/2) * ||V||_2^2

Parameters:
  • x – a batch of observations with shape (n_samples, n_features)

  • code – a batch of code vectors with shape (n_samples, n_components)

  • lambd – weight decay parameter

Returns:
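
The objective above has a closed-form minimizer; the sketch below (a hypothetical ridge_dict_update, not necessarily this function's exact code) shows that solution:

    import torch

    def ridge_dict_update(x: torch.Tensor, code: torch.Tensor,
                          lambd: float = 1e-4) -> torch.Tensor:
        # Minimize f(V) = (1/(2N)) * ||code @ V.T - x||_2^2 + (lambd/2) * ||V||_2^2
        # via the normal equations: V.T = (code.T @ code + N*lambd*I)^-1 @ code.T @ x.
        n = x.shape[0]
        gram = code.T @ code + n * lambd * torch.eye(code.shape[1])
        return torch.linalg.solve(gram, code.T @ x).T  # (n_features, n_components)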

torch_staintools.functional.optimization.solver module

Code directly adapted from https://github.com/rfeinman/pytorch-lasso.

torch_staintools.functional.optimization.solver._lipschitz_constant(W)

Find the Lipschitz constant used to set the learning rate in ISTA.

Parameters:

W – weights w in f(z) = ||Wz - x||^2

Returns:
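
The gradient of f(z) = ||Wz - x||^2 is 2 W^T (Wz - x), so its Lipschitz constant is 2 * sigma_max(W)^2, where sigma_max is the largest singular value of W. A sketch follows (the lipschitz_constant name is hypothetical, and the module may compute this differently, e.g., via an eigendecomposition of W^T W):

    import torch

    def lipschitz_constant(W: torch.Tensor) -> torch.Tensor:
        # Spectral norm of W, squared and doubled, bounds the Lipschitz
        # constant of the gradient 2 * W.T @ (W @ z - x).
        return 2.0 * torch.linalg.matrix_norm(W, ord=2) ** 2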

torch_staintools.functional.optimization.solver.coord_descent(x, W, z0=None, alpha=1.0, lambda1=0.01, maxiter=1000, tol=1e-06, verbose=False)

Modified coordinate descent solver.

torch_staintools.functional.optimization.solver.ista(x, z0, weight, alpha=1.0, fast=True, lr='auto', maxiter=50, tol=1e-05, lambda1=0.01, verbose=False, rng: Generator | None = None)

ISTA solver

Parameters:
  • x – data

  • z0 – code, or the initialization mode of the code.

  • weight – dictionary

  • alpha – eps term for code initialization

  • fast – whether to use FISTA (fast ISTA) instead of ISTA

  • lr – learning rate/step size. If ‘auto’, it is derived from the Lipschitz constant of f(z) = ||Wz - x||^2

  • maxiter – maximum number of iterations if not converged.

  • tol – tolerance of the convergence test.

  • lambda1 – weight (lambda) of the sparsity term.

  • verbose – whether to print the progress

  • rng – torch.Generator for random initialization

Returns:
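
A usage sketch; the shapes are assumptions (x: (n_samples, n_features), dictionary weight: (n_features, n_components), code z: (n_samples, n_components)), and an explicit zero tensor is passed for z0 rather than an initialization-mode string:

    import torch
    from torch_staintools.functional.optimization.solver import ista

    rng = torch.Generator().manual_seed(0)
    x = torch.rand(1024, 3)               # data
    weight = torch.rand(3, 2)             # dictionary
    z0 = torch.zeros(1024, 2)             # explicit initial code

    z = ista(x, z0, weight, fast=True, lr='auto', lambda1=0.01, rng=rng)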

torch_staintools.functional.optimization.sparse_util module

torch_staintools.functional.optimization.sparse_util.initialize_code(x, weight, alpha, mode, rng: Generator)

Code initialization in dictionary learning.

Dictionary learning finds a sparse decomposition of the data, X = D * A, where D is the dictionary and A is the code.

Parameters:
  • x – data

  • weight – dictionary

  • alpha – small eps term on diagonal for ridge initialization

  • mode – code initialization method

  • rng – torch.Generator for random initialization modes

Returns:

torch_staintools.functional.optimization.sparse_util.ridge(b: Tensor, A: Tensor, alpha: float = 0.0001)
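
A sketch of the ridge solve such a helper typically performs, min ||A z - b||^2 + alpha * ||z||^2; the semantics and the hypothetical ridge_solve name are assumptions based on the signature above:

    import torch

    def ridge_solve(b: torch.Tensor, A: torch.Tensor,
                    alpha: float = 1e-4) -> torch.Tensor:
        # Closed-form ridge solution: z = (A^T A + alpha * I)^-1 A^T b.
        gram = A.T @ A + alpha * torch.eye(A.shape[1])
        return torch.linalg.solve(gram, A.T @ b)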

Module contents