torch_staintools.functional.optimization package

Submodules

torch_staintools.functional.optimization.dict_learning module

code directly adapted from https://github.com/rfeinman/pytorch-lasso

torch_staintools.functional.optimization.dict_learning.dict_learning(x: Tensor, tissue_mask_flatten: Tensor | None, n_components: int, algorithm: Literal['ista', 'cd', 'fista'], *, alpha: float, lambd_ridge: float, steps: int, rng: Generator | None, init: str | None, lr: float | None, maxiter: int)
torch_staintools.functional.optimization.dict_learning.dict_learning_loop(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, algorithm: Literal['ista', 'cd', 'fista'], *, lambd_ridge: float, steps: int, rng: Generator, init: str | None, lr: Tensor, maxiter: int)
torch_staintools.functional.optimization.dict_learning.sparse_code(x: Tensor, weight: Tensor, alpha: Tensor, z0: Tensor, algorithm: Literal['ista', 'cd', 'fista'], lr: Tensor, maxiter: int, positive_code: bool)
torch_staintools.functional.optimization.dict_learning.update_dict_cd(dictionary: Tensor, x: Tensor, code: Tensor, positive: bool = True, dead_thresh=1e-07, rng: Generator = None) Tuple[Tensor, Tensor]

Update the dictionary (stain matrix) using the block coordinate descent algorithm.

Can enforce the positivity constraint on the dictionary if specified. Side effect: code is updated in-place.

Parameters:
  • dictionary – Tensor of shape (B, n_features, n_components). Value of the dictionary at the previous iteration. n_features in this context can be the number of input color channels; n_components can be the number of stains to separate.

  • x – Tensor of shape (B, n_samples, n_features). Data against which to optimize the dictionary.

  • code – Tensor of shape (B, n_samples, n_components) Sparse coding of the data against which to optimize the dictionary.

  • positive – Whether to enforce positivity when finding the dictionary.

  • dead_thresh – Minimum vector norm below which a dictionary atom is considered degenerate.

  • rng – torch.Generator for initialization of dictionary and code.

Returns:

torch.Tensor, torch.Tensor, corresponding to the weight and the updated code.
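The update above follows the standard block coordinate descent scheme for dictionary learning: each atom is updated in turn while the others are held fixed, then clamped and renormalized. A minimal single-batch NumPy sketch of the idea (the function name `update_dict_cd_np` and its details are illustrative assumptions, not the library's torch implementation):

```python
import numpy as np

def update_dict_cd_np(D, x, z, positive=True, dead_thresh=1e-7, rng=None):
    # D: (n_features, n_components); x: (n_samples, n_features); z: (n_samples, n_components)
    rng = rng or np.random.default_rng()
    A = z.T @ z            # (k, k) code Gram matrix
    B = x.T @ z            # (d, k) data/code correlation
    for j in range(D.shape[1]):
        # block update of atom j, holding the other atoms fixed
        D[:, j] += (B[:, j] - D @ A[:, j]) / max(A[j, j], 1e-12)
        if positive:
            D[:, j] = np.maximum(D[:, j], 0.0)   # positivity constraint on the dictionary
        norm = np.linalg.norm(D[:, j])
        if norm < dead_thresh:
            D[:, j] = rng.random(D.shape[0])     # re-seed a degenerate ("dead") atom
            norm = np.linalg.norm(D[:, j])
        D[:, j] /= norm                          # keep atoms unit-norm
    return D
```

The atom-wise closed-form step comes from minimizing the quadratic reconstruction error in one column of D at a time.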

torch_staintools.functional.optimization.dict_learning.update_dict_ridge(x: Tensor, code: Tensor, lambd: float) Tuple[Tensor, Tensor]

Update an (unconstrained) dictionary with ridge regression

This is equivalent to a Newton step on the (L2-regularized) squared-error objective: f(V) = (1/2N) * ||Vz - x||_2^2 + (lambd/2) * ||V||_2^2. It may have severe numerical stability issues compared to update_dict_cd.

Parameters:
  • x – a batch of observations with shape (B, n_samples, n_features)

  • code – a batch of code vectors with shape (B, n_samples, n_components)

  • lambd – weight decay parameter

Returns:

torch.Tensor, torch.Tensor, corresponding to the weight and the unmodified code.
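With the 1/(2N) scaling above, the ridge update has a closed form: setting the gradient to zero gives V = X^T Z (Z^T Z + N·lambd·I)^{-1}. A single-batch NumPy sketch of that formula (illustrative; the library's batched torch code is not shown here):

```python
import numpy as np

def update_dict_ridge_np(x, code, lambd):
    # x: (n_samples, n_features); code: (n_samples, n_components)
    # Closed-form minimizer of f(V) = (1/2N)||Vz - x||^2 + (lambd/2)||V||^2:
    #   V = X^T Z (Z^T Z + N*lambd*I)^{-1}
    n = x.shape[0]
    k = code.shape[1]
    gram = code.T @ code + n * lambd * np.eye(k)
    return x.T @ code @ np.linalg.inv(gram)
```

The explicit inverse of the (possibly ill-conditioned) Gram matrix is where the numerical stability concerns mentioned above come from.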

torch_staintools.functional.optimization.linear_solver module

torch_staintools.functional.optimization.linear_solver.lstsq_solver(od_flatten: Tensor, dictionary: Tensor, positive: bool)

Least squares solver for concentration.

Warning

For the cusolver backend, algorithm = ‘ls’ may fail on GPU for large individual inputs (e.g., 1000 x 1000), regardless of batch size. To use ‘ls’ on large images, consider the magma backend: `torch.backends.cuda.preferred_linalg_library('magma')`

Parameters:
  • od_flatten – B x num_pixels x C

  • dictionary – Transpose of stain matrix. B x C x num_stains

  • positive – enforce positive concentration

Returns:

B x num_pixels x num_stains

Return type:

concentration (flattened)
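A single-batch NumPy sketch of the least-squares idea, assuming the Beer-Lambert relation od ≈ concentration @ dictionary.T (the helper name and the clamp-based positivity handling are illustrative assumptions, not the library's exact behavior):

```python
import numpy as np

def lstsq_concentration(od_flatten, dictionary, positive=True):
    # od_flatten: (num_pixels, C); dictionary: (C, num_stains), i.e. stain matrix transpose
    # od ≈ conc @ dictionary.T, so solve dictionary @ conc.T = od.T in the
    # least-squares sense, one column per pixel.
    conc = np.linalg.lstsq(dictionary, od_flatten.T, rcond=None)[0].T
    if positive:
        conc = np.clip(conc, 0.0, None)  # crude positivity: clamp at zero
    return conc
```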

torch_staintools.functional.optimization.linear_solver.pinv_solver(od_flatten: Tensor, dictionary: Tensor, positive: bool)

Pseudo-inverse solver for concentration.

Parameters:
  • od_flatten – B x num_pixels x C

  • dictionary – Transpose of stain matrix. B x C x num_stains

  • positive – enforce positive concentration

Returns:

B x num_pixels x num_stains

Return type:

concentration (flattened)
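The pseudo-inverse route solves the same linear model directly. A single-batch NumPy sketch (name and positivity clamp are illustrative assumptions):

```python
import numpy as np

def pinv_concentration(od_flatten, dictionary, positive=True):
    # od ≈ conc @ dictionary.T, hence conc = od @ pinv(dictionary.T)
    conc = od_flatten @ np.linalg.pinv(dictionary.T)
    if positive:
        conc = np.clip(conc, 0.0, None)  # crude positivity: clamp at zero
    return conc
```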

torch_staintools.functional.optimization.linear_solver.qr_solver(od_flatten, dictionary, positive: bool)

QR solver for concentration.

Parameters:
  • od_flatten – B x num_pixels x C

  • dictionary – Transpose of stain matrix. B x C x num_stains

  • positive – enforce positive concentration

Returns:

B x num_pixels x num_stains

Return type:

concentration (flattened)
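A QR factorization turns the least-squares problem into a small triangular solve. A single-batch NumPy sketch of the idea (illustrative; the library's torch version is batched):

```python
import numpy as np

def qr_concentration(od_flatten, dictionary, positive=True):
    # Solve dictionary @ conc.T = od.T via the thin QR factorization:
    # D = Q R  =>  R conc.T = Q^T od.T, with R square and upper-triangular.
    q, r = np.linalg.qr(dictionary)
    conc = np.linalg.solve(r, q.T @ od_flatten.T).T
    if positive:
        conc = np.clip(conc, 0.0, None)
    return conc
```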

torch_staintools.functional.optimization.linear_solver.qr_solver_generic(od_flatten, dictionary, positive: bool)

QR solver for concentration.

Warning

This does not scale with the size of individual images as well as the 2-stain unrolled version `qr_solver_two_stain`.

Parameters:
  • od_flatten – B x num_pixels x C

  • dictionary – Transpose of stain matrix. B x C x num_stains

  • positive – enforce positive concentration

Returns:

B x num_pixels x num_stains

Return type:

concentration (flattened)

torch_staintools.functional.optimization.linear_solver.qr_solver_two_stain(od_flatten, dictionary, positive: bool)

QR solver for concentration (hardcoded 2-stain computation)

Parameters:
  • od_flatten – B x num_pixels x C

  • dictionary – Transpose of stain matrix. B x C x num_stains

  • positive – enforce positive concentration

Returns:

B x num_pixels x num_stains

Return type:

concentration (flattened)
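For exactly two stains, the normal equations reduce to a 2×2 system that can be inverted in closed form per pixel, avoiding a general factorization. A NumPy sketch of one way to unroll it (the helper name, normal-equations route, and positivity clamp are assumptions, not necessarily the library's exact unrolling):

```python
import numpy as np

def two_stain_solve_np(od_flatten, dictionary, positive=True):
    # dictionary: (C, 2). Solve (D^T D) c = D^T od per pixel with the
    # explicit 2x2 inverse instead of a general solver.
    g = dictionary.T @ dictionary              # 2x2 Gram matrix
    rhs = od_flatten @ dictionary              # (num_pixels, 2)
    det = g[0, 0] * g[1, 1] - g[0, 1] * g[1, 0]
    c0 = (g[1, 1] * rhs[:, 0] - g[0, 1] * rhs[:, 1]) / det
    c1 = (g[0, 0] * rhs[:, 1] - g[1, 0] * rhs[:, 0]) / det
    conc = np.stack([c0, c1], axis=1)
    return np.clip(conc, 0.0, None) if positive else conc
```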

torch_staintools.functional.optimization.sparse_solver module

code directly adapted from https://github.com/rfeinman/pytorch-lasso

torch_staintools.functional.optimization.sparse_solver.cd_loop(z: Tensor, b: Tensor, s: Tensor, alpha: Tensor, maxiter: int, positive_code: bool) Tensor
torch_staintools.functional.optimization.sparse_solver.cd_step(z: Tensor, b: Tensor, s: Tensor, alpha: Tensor, positive_code: bool) tuple[Tensor, Tensor]
torch_staintools.functional.optimization.sparse_solver.coord_descent(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, maxiter: int, positive_code: bool)

Modified coordinate descent solver (adapted from pytorch-lasso's coord_descent).

torch_staintools.functional.optimization.sparse_solver.fista(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, lr: Tensor, maxiter: int, positive_code)

Fast ISTA solver

Parameters:
  • x – data

  • z0 – code, or the initialization mode of the code.

  • weight – dictionary (transposed stain matrix)

  • alpha – penalty term for the code

  • lr – learning rate/step size. If ‘auto’, it is derived from the Lipschitz constant of f(z) = ||Wz - x||^2

  • maxiter – maximum number of iterations if not converged.

  • positive_code – whether to enforce the positivity constraint on z

Returns:

torch_staintools.functional.optimization.sparse_solver.fista_loop(z: Tensor, hessian: Tensor, b: Tensor, alpha: Tensor, lr: Tensor, maxiter: int, positive_code: bool = True) Tensor

FISTA Loop

Parameters:
  • z – Initial guess

  • x – Data input (OD space). Not a direct argument; it enters via the precomputed b

  • weight – Dictionary matrix. Not a direct argument; it enters via the precomputed hessian and b

  • hessian – precomputed wtw

  • b – precomputed xw

  • alpha – Regularization strength

  • lr – Learning rate

  • maxiter – Maximum iterations

  • positive_code – whether to enforce the positivity constraint on the code

torch_staintools.functional.optimization.sparse_solver.fista_step(z: Tensor, y: Tensor, t: Tensor, hessian: Tensor, b: Tensor, alpha: Tensor, lr: Tensor, positive_code: bool, keep: Tensor)
torch_staintools.functional.optimization.sparse_solver.ista(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, lr: Tensor, maxiter: int, positive_code: bool)

ISTA solver

Parameters:
  • x – data

  • z0 – code, or the initialization mode of the code.

  • weight – dictionary (transposed stain matrix)

  • alpha – penalty term for the code

  • lr – learning rate/step size. If ‘auto’, it is derived from the Lipschitz constant of f(z) = ||Wz - x||^2

  • maxiter – maximum number of iterations if not converged.

  • positive_code – whether to enforce the positivity constraint on z

Returns:
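ISTA alternates a gradient step on the smooth reconstruction term with soft-thresholding. A single-batch NumPy sketch (illustrative, with hypothetical names; the library's batched torch version precomputes the hessian wtw and b = xw):

```python
import numpy as np

def soft_shrink(v, lam, positive=False):
    # proximal operator of lam * ||.||_1, optionally restricted to z >= 0
    out = np.sign(v) * np.maximum(np.abs(v) - lam, 0.0)
    return np.maximum(out, 0.0) if positive else out

def ista_np(x, weight, alpha, lr, maxiter=2000, positive_code=False):
    # x: (n_samples, n_features); weight W: (n_features, n_components)
    # minimize (1/2)||z W^T - x||^2 + alpha * ||z||_1 over the code z
    z = np.zeros((x.shape[0], weight.shape[1]))
    for _ in range(maxiter):
        grad = (z @ weight.T - x) @ weight            # gradient of the smooth term
        z = soft_shrink(z - lr * grad, lr * alpha, positive_code)
    return z
```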

torch_staintools.functional.optimization.sparse_solver.ista_loop(z: Tensor, hessian: Tensor, b: Tensor, *, alpha: Tensor, lr: Tensor, maxiter: int, positive_code: bool)
torch_staintools.functional.optimization.sparse_solver.ista_step(z: Tensor, hessian: Tensor, b: Tensor, alpha: Tensor, lr: Tensor, positive: bool) Tensor
Parameters:
  • z – code. B x num_pixels x num_stain

  • x – OD space. B x num_pixels x num_channel. Not a direct argument; it enters via the precomputed b

  • weight – initialized from the stain matrix, B x num_channel x num_stains. Not a direct argument; it enters via the precomputed hessian and b

  • hessian – precomputed wtw

  • b – precomputed xw

  • alpha – the ISTA penalty term, as a tensor

  • lr – the step size, as a tensor

  • positive – whether to force z to be positive

Returns:

torch_staintools.functional.optimization.sparse_solver.rss_grad(z_k: Tensor, x: Tensor, weight: Tensor)
torch_staintools.functional.optimization.sparse_solver.rss_grad_fast(z_k: Tensor, hessian: Tensor, b: Tensor)
torch_staintools.functional.optimization.sparse_solver.softshrink(x: Tensor, lambd: Tensor, positive: bool) Tensor
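Soft-shrinkage is the elementwise proximal operator of the L1 penalty used by the ISTA/FISTA steps above. A NumPy sketch of the operation (the function name is illustrative):

```python
import numpy as np

def softshrink_np(x, lambd, positive=False):
    # elementwise soft-thresholding: shrink toward zero by lambd,
    # zeroing anything with magnitude below lambd
    out = np.sign(x) * np.maximum(np.abs(x) - lambd, 0.0)
    # with the positivity constraint, negative outputs are clamped to zero
    return np.maximum(out, 0.0) if positive else out
```

For example, `softshrink_np(np.array([2.0, -2.0, 0.5]), 1.0)` gives `[1.0, -1.0, 0.0]`.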

torch_staintools.functional.optimization.sparse_util module

torch_staintools.functional.optimization.sparse_util.collate_params(x: Tensor, lr: float | Tensor | None, weight: Tensor, alpha: float | Tensor) Tuple[Tensor, Tensor]
torch_staintools.functional.optimization.sparse_util.initialize_code(x: Tensor, weight: Tensor, mode: Literal['zero', 'transpose', 'unif', 'ridge'], rng: Generator)

Code initialization in dictionary learning.

Dictionary learning finds a sparse decomposition of the data, X = D * A, where D is the dictionary and A is the code. For ridge initialization, the L2 penalty is set by constants.INIT_RIDGE_L2

Parameters:
  • x – data. B x num_pixel x num_channel

  • weight – dictionary. B x num_channel x num_stains. Essentially the transposed stain matrix

  • mode – code initialization method

  • rng – torch.Generator for random initialization modes

Returns:

torch_staintools.functional.optimization.sparse_util.initialize_dict(shape: Tuple[int, ...], *, device: device | str, rng: Generator, norm_dim: int, positive_dict: bool)
torch_staintools.functional.optimization.sparse_util.lipschitz_constant(w: Tensor)

Find the Lipschitz constant used to set the learning rate (step size) in ISTA.

Parameters:

w – weights w in f(z) = ||Wz - x||^2

Returns:
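For a quadratic term, the Lipschitz constant of the gradient is the largest eigenvalue of W^T W, i.e. the squared spectral norm of W. A NumPy sketch under the (1/2)||Wz - x||^2 convention (the name is illustrative; without the 1/2 factor the constant doubles):

```python
import numpy as np

def lipschitz_constant_np(w):
    # grad of f(z) = (1/2)||W z - x||^2 is W^T (W z - x); its Lipschitz
    # constant is lambda_max(W^T W) = sigma_max(W)^2.
    return np.linalg.norm(w, ord=2) ** 2
```

ISTA/FISTA converge for step sizes lr <= 1 / L, so lr is typically set to the reciprocal of this value.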

torch_staintools.functional.optimization.sparse_util.ridge(b: Tensor, a: Tensor, alpha: float | None = None)
torch_staintools.functional.optimization.sparse_util.to_tensor(v: float | Tensor, like: Tensor) Tensor
torch_staintools.functional.optimization.sparse_util.validate_code(algorithm: Literal['ista', 'cd', 'fista'], init: Literal['zero', 'transpose', 'unif', 'ridge'] | None, z0: Tensor | None, x: Tensor, weight: Tensor, rng)

Module contents