torch_staintools.functional.optimization package

Submodules

torch_staintools.functional.optimization.dict_learning module

Code directly adapted from https://github.com/rfeinman/pytorch-lasso.

torch_staintools.functional.optimization.dict_learning.dict_learning(x: Tensor, n_components: int, algorithm: Literal['ista', 'cd', 'fista'], *, alpha: float, lambd_ridge: float, steps: int, rng: Generator | None, init: str | None, lr: float | None, maxiter: int, tol: float)
torch_staintools.functional.optimization.dict_learning.dict_learning_loop(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, algorithm: Literal['ista', 'cd', 'fista'], *, lambd_ridge: float, steps: int, rng: Generator, init: str | None, lr: Tensor, maxiter: int, tol: float)
torch_staintools.functional.optimization.dict_learning.sparse_code(x: Tensor, weight: Tensor, alpha: Tensor, z0: Tensor, algorithm: Literal['ista', 'cd', 'fista'], lr: Tensor, maxiter: int, tol: float, positive_code: bool)
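
A hedged usage sketch of the top-level dict_learning entry point. The synthetic input and the assumed return convention (dictionary first, code second) are illustrative only:

    import torch
    from torch_staintools.functional.optimization.dict_learning import dict_learning

    # Optical-density pixels: (n_samples, n_features), e.g. 3 RGB-derived channels.
    od = torch.rand(4096, 3)

    # Learn a 2-component stain dictionary with FISTA sparse coding.
    # Assumed return convention: (dictionary, code) -- check the source to confirm.
    dictionary, code = dict_learning(
        od, n_components=2, algorithm='fista',
        alpha=0.1, lambd_ridge=0.01, steps=30,
        rng=torch.Generator().manual_seed(0),
        init='ridge', lr=None, maxiter=50, tol=1e-5,
    )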
torch_staintools.functional.optimization.dict_learning.update_dict_cd(dictionary: Tensor, x: Tensor, code: Tensor, positive: bool = True, dead_thresh=1e-07, rng: Generator = None) Tuple[Tensor, Tensor]

Update the dictionary (stain matrix) using the block coordinate descent algorithm.

Can enforce a positivity constraint on the dictionary if specified. Side effect: code is updated in place. A minimal sketch follows the parameter list.

Parameters:
  • dictionary – Tensor of shape (n_features, n_components). Value of the dictionary at the previous iteration.

  • x – Tensor of shape (n_samples, n_features). Data against which to optimize the dictionary.

  • code – Tensor of shape (n_samples, n_components). Sparse code of the data, against which the dictionary is optimized.

  • positive – Whether to enforce positivity when finding the dictionary.

  • dead_thresh – Minimum vector norm below which a dictionary atom is considered degenerate.

  • rng – torch.Generator for initialization of dictionary and code.

Returns:

torch.Tensor, torch.Tensor: the updated dictionary (weight) and the updated code.
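
A minimal sketch of one block coordinate descent dictionary update in plain PyTorch, under the shape conventions above. bcd_dict_update_sketch is a hypothetical stand-in; the library's update_dict_cd may differ in details such as atom renormalization and dead-atom revival:

    import torch

    def bcd_dict_update_sketch(dictionary, x, code, positive=True, dead_thresh=1e-7):
        # dictionary V: (n_features, n_components); x: (n_samples, n_features);
        # code: (n_samples, n_components). Minimizes ||code @ V.T - x||^2 over V,
        # one atom (column) at a time, holding the other atoms fixed.
        A = code.T @ code          # (n_components, n_components) Gram matrix
        B = x.T @ code             # (n_features, n_components)
        V = dictionary.clone()
        for k in range(V.shape[1]):
            V[:, k] += (B[:, k] - V @ A[:, k]) / A[k, k].clamp(min=1e-12)
            if positive:
                V[:, k] = V[:, k].clamp(min=0)            # positivity constraint
            norm = V[:, k].norm()
            if norm < dead_thresh:                        # revive a degenerate atom
                V[:, k] = torch.rand_like(V[:, k])
                norm = V[:, k].norm()
            V[:, k] = V[:, k] / norm                      # renormalize the atom
        return V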

torch_staintools.functional.optimization.dict_learning.update_dict_ridge(x: Tensor, code: Tensor, lambd: float) Tuple[Tensor, Tensor]

Update an (unconstrained) dictionary with ridge regression.

This is equivalent to a Newton step on the (L2-regularized) squared-error objective f(V) = (1/2N) * ||Vz - x||_2^2 + (lambd/2) * ||V||_2^2. It may have severe numerical stability issues compared to update_dict_cd.

Parameters:
  • x – a batch of observations with shape (n_samples, n_features)

  • code – a batch of code vectors with shape (n_samples, n_components)

  • lambd – weight decay parameter

Returns:

torch.Tensor, torch.Tensor: the updated dictionary (weight) and the unmodified code.
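
Since the objective is quadratic in V, this update has a closed form. A minimal sketch under the shapes above (ridge_dict_update_sketch is a hypothetical name; the library may organize the computation differently):

    import torch

    def ridge_dict_update_sketch(x, code, lambd):
        # Minimizer of f(V) = (1/2N)*||V z - x||_2^2 + (lambd/2)*||V||_2^2:
        #   V = (x^T code) @ (code^T code + N*lambd*I)^{-1}
        n, k = x.shape[0], code.shape[1]
        gram = code.T @ code + n * lambd * torch.eye(k, dtype=x.dtype, device=x.device)
        # Solve gram @ V^T = code^T @ x rather than inverting (gram is symmetric).
        V = torch.linalg.solve(gram, code.T @ x).T
        return V  # (n_features, n_components)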

torch_staintools.functional.optimization.solver module

Code directly adapted from https://github.com/rfeinman/pytorch-lasso.

torch_staintools.functional.optimization.solver.coord_descent(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, maxiter: int, tol: float, positive_code: bool)

Modified coordinate descent solver.
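
A minimal sketch of cyclic coordinate descent for the lasso sub-problem, vectorized over samples; hedged, since the library's modifications are not spelled out here:

    import torch
    import torch.nn.functional as F

    def coord_descent_sketch(x, weight, alpha, maxiter=100, tol=1e-6, positive_code=True):
        # Minimize (1/2)*||z @ W.T - x||^2 + alpha*||z||_1 over the code z.
        hessian = weight.T @ weight        # W^T W, (num_stain, num_stain)
        b = x @ weight                     # x W, (num_pixels, num_stain)
        z = torch.zeros_like(b)
        for _ in range(maxiter):
            z_old = z.clone()
            for j in range(z.shape[1]):
                # Correlation of the partial residual with coordinate j, all samples at once.
                rho = b[:, j] - z @ hessian[:, j] + z[:, j] * hessian[j, j]
                z_j = F.softshrink(rho, lambd=float(alpha)) / hessian[j, j]
                z[:, j] = z_j.clamp(min=0) if positive_code else z_j
            if (z - z_old).abs().max() < tol:
                break
        return z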

torch_staintools.functional.optimization.solver.fista(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, lr: Tensor, maxiter: int, tol: float, positive_code)

FISTA (fast ISTA) solver.

Parameters:
  • x – data in optical density (OD) space

  • z0 – initial code, or the initialization mode of the code

  • weight – dictionary (stain matrix)

  • alpha – sparsity penalty on the code

  • lr – learning rate/step size. If 'auto', it is set from the Lipschitz constant of f(z) = ||Wz - x||^2

  • maxiter – maximum number of iterations if not converged

  • tol – tolerance of the convergence test

  • positive_code – whether to enforce a positivity constraint on z

Returns:

the optimized sparse code z.
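
A self-contained sketch of the FISTA iteration this solver implements (simplified: zero initialization, scalar alpha, and lr derived from the Lipschitz constant; the library's handling of z0 modes and tensor-valued parameters is richer):

    import torch
    import torch.nn.functional as F

    def fista_sketch(x, weight, alpha, maxiter=200, tol=1e-6, positive_code=True):
        # Minimize (1/2)*||z @ W.T - x||^2 + alpha*||z||_1 over the code z.
        hessian = weight.T @ weight                      # precomputed W^T W
        b = x @ weight                                   # precomputed x W
        lr = 1.0 / torch.linalg.eigvalsh(hessian)[-1]    # 1 / Lipschitz constant
        z = torch.zeros_like(b)
        y, t = z.clone(), 1.0
        for _ in range(maxiter):
            grad = y @ hessian - b                       # gradient of the residual term at y
            z_next = F.softshrink(y - lr * grad, lambd=float(lr * alpha))
            if positive_code:
                z_next = z_next.clamp(min=0)
            t_next = (1 + (1 + 4 * t ** 2) ** 0.5) / 2
            y = z_next + ((t - 1) / t_next) * (z_next - z)   # momentum extrapolation
            if (z_next - z).abs().max() < tol:
                return z_next
            z, t = z_next, t_next
        return z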

torch_staintools.functional.optimization.solver.fista_loop(z: Tensor, hessian: Tensor, b: Tensor, alpha: Tensor, lr: Tensor, tol: float, maxiter: int, positive_code: bool = True) Tensor

FISTA Loop

Parameters:
  • z – Initial guess of the code

  • hessian – precomputed W^T W, where W is the dictionary matrix

  • b – precomputed x W, where x is the data input (OD space)

  • alpha – Regularization strength

  • lr – Learning rate

  • maxiter – Maximum iterations

  • tol – Convergence tolerance

  • positive_code – whether to enforce a positivity constraint on z

torch_staintools.functional.optimization.solver.fista_step(z: Tensor, y: Tensor, t: Tensor, hessian: Tensor, b: Tensor, alpha: Tensor, lr: Tensor, positive_code: bool, tol: float)
torch_staintools.functional.optimization.solver.ista(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, lr: Tensor, maxiter: int, tol: float, positive_code: bool)

ISTA solver

Parameters:
  • x – data in optical density (OD) space

  • z0 – initial code, or the initialization mode of the code

  • weight – dictionary (stain matrix)

  • alpha – sparsity penalty on the code

  • lr – learning rate/step size. If 'auto', it is set from the Lipschitz constant of f(z) = ||Wz - x||^2

  • maxiter – maximum number of iterations if not converged

  • tol – tolerance of the convergence test

  • positive_code – whether to enforce a positivity constraint on z

Returns:

the optimized sparse code z.

torch_staintools.functional.optimization.solver.ista_loop(z: Tensor, hessian: Tensor, b: Tensor, alpha: Tensor, lr: Tensor, tol: float, maxiter: int, positive_code: bool)
torch_staintools.functional.optimization.solver.ista_step(z: Tensor, hessian: Tensor, b: Tensor, alpha: Tensor, lr: Tensor, positive: bool) Tensor
Parameters:
  • z – code, shape (num_pixels, num_stain)

  • hessian – precomputed W^T W, where W (num_channel x num_stain) is the dictionary initialized from the stain matrix

  • b – precomputed x W, where x (num_pixels x num_channel) is the data in OD space

  • alpha – tensor form of the ISTA penalty

  • lr – tensor form of the step size

  • positive – whether to force z to be positive

Returns:

torch.Tensor: the updated code after one ISTA step.
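
Under the conventions above, one ISTA step is a gradient step on the residual term followed by soft-thresholding; a one-function sketch:

    import torch
    import torch.nn.functional as F

    def ista_step_sketch(z, hessian, b, alpha, lr, positive):
        grad = z @ hessian - b                      # d/dz of (1/2)*||z @ W.T - x||^2
        z_new = F.softshrink(z - lr * grad, lambd=float(lr * alpha))
        return z_new.clamp(min=0) if positive else z_new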

torch_staintools.functional.optimization.solver.rss_grad(z_k: Tensor, x: Tensor, weight: Tensor)
torch_staintools.functional.optimization.solver.rss_grad_fast(z_k: Tensor, hessian: Tensor, b: Tensor)
torch_staintools.functional.optimization.solver.softshrink(x: Tensor, lambd: Tensor) Tensor
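
softshrink is the proximal operator of the L1 penalty (soft-thresholding). A sketch of the standard definition it implements:

    import torch

    def softshrink_sketch(x, lambd):
        # sign(x) * max(|x| - lambd, 0)
        return torch.sign(x) * torch.clamp(x.abs() - lambd, min=0)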

torch_staintools.functional.optimization.sparse_util module

torch_staintools.functional.optimization.sparse_util.as_scalar(v: float | Tensor, like: Tensor) Tensor
torch_staintools.functional.optimization.sparse_util.collate_params(z0: Tensor, x: Tensor, lr: float | Tensor | None, weight: Tensor, alpha: float | Tensor, tol: float) Tuple[Tensor, Tensor, float]
torch_staintools.functional.optimization.sparse_util.initialize_code(x: Tensor, weight: Tensor, mode: Literal['zero', 'transpose', 'unif', 'ridge'], rng: Generator)

Code initialization in dictionary learning.

Dictionary learning finds a sparse decomposition of the data, X = D * A, where D is the dictionary and A is the code. For ridge initialization, the L2 penalty is set by the constant constants.INIT_RIDGE_L2.

Parameters:
  • x – data

  • weight – dictionary

  • mode – code initialization method

  • rng – torch.Generator for random initialization modes

Returns:

the initialized code.
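
A hypothetical sketch of the four modes (the exact formulas, including the value standing in for constants.INIT_RIDGE_L2, are assumptions):

    import torch

    def initialize_code_sketch(x, weight, mode, rng=None):
        # x: (n_samples, n_features); weight: (n_features, n_components)
        n, k = x.shape[0], weight.shape[1]
        if mode == 'zero':
            return x.new_zeros(n, k)
        if mode == 'transpose':
            return x @ weight                      # back-project data through W
        if mode == 'unif':
            return torch.rand(n, k, generator=rng, dtype=x.dtype, device=x.device)
        if mode == 'ridge':
            l2 = 0.1                               # placeholder for constants.INIT_RIDGE_L2
            gram = weight.T @ weight + l2 * torch.eye(k, dtype=x.dtype, device=x.device)
            return torch.linalg.solve(gram, (x @ weight).T).T
        raise ValueError(f"unknown init mode: {mode}")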

torch_staintools.functional.optimization.sparse_util.initialize_dict(n_features: int, n_components: int, device: device | str, rng: Generator, positive_dict: bool)
torch_staintools.functional.optimization.sparse_util.lipschitz_constant(w: Tensor)

Find the Lipschitz constant used to set the learning rate in ISTA.

Parameters:
  • w – the weight matrix W in f(z) = ||Wz - x||^2

Returns:

the Lipschitz constant, from which the ISTA step size is derived.
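
Assuming the standard bound for this quadratic, the constant is the top eigenvalue of W^T W (times 2 if f carries no 1/2 factor); a sketch:

    import torch

    def lipschitz_constant_sketch(w):
        # Largest eigenvalue of W^T W bounds the curvature of ||Wz - x||^2,
        # so its reciprocal is a safe ISTA/FISTA step size.
        return torch.linalg.eigvalsh(w.T @ w)[-1]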

torch_staintools.functional.optimization.sparse_util.ridge(b: Tensor, a: Tensor, alpha: float | None = None)
torch_staintools.functional.optimization.sparse_util.validate_code(algorithm: Literal['ista', 'cd', 'fista'], init: Literal['zero', 'transpose', 'unif', 'ridge'] | None, z0: Tensor | None, x: Tensor, weight: Tensor, rng)

Module contents