torch_staintools.functional.optimization package
Submodules
torch_staintools.functional.optimization.dict_learning module
Code directly adapted from https://github.com/rfeinman/pytorch-lasso.
- torch_staintools.functional.optimization.dict_learning.dict_learning(x: Tensor, n_components: int, algorithm: Literal['ista', 'cd', 'fista'], *, alpha: float, lambd_ridge: float, steps: int, rng: Generator | None, init: str | None, lr: float | None, maxiter: int, tol: float)
- torch_staintools.functional.optimization.dict_learning.dict_learning_loop(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, algorithm: Literal['ista', 'cd', 'fista'], *, lambd_ridge: float, steps: int, rng: Generator, init: str | None, lr: Tensor, maxiter: int, tol: float)
- torch_staintools.functional.optimization.dict_learning.sparse_code(x: Tensor, weight: Tensor, alpha: Tensor, z0: Tensor, algorithm: Literal['ista', 'cd', 'fista'], lr: Tensor, maxiter: int, tol: float, positive_code: bool)
- torch_staintools.functional.optimization.dict_learning.update_dict_cd(dictionary: Tensor, x: Tensor, code: Tensor, positive: bool = True, dead_thresh=1e-07, rng: Generator = None) Tuple[Tensor, Tensor]
Update the dictionary (stain matrix) using the block coordinate descent algorithm.
Can enforce a positivity constraint on the dictionary if specified. Side effect: code is updated in place. A minimal sketch of one update pass follows the parameter list.
- Parameters:
dictionary – Tensor of shape (n_features, n_components). Value of the dictionary at the previous iteration.
x – Tensor of shape (n_samples, n_features). Data against which to optimize the dictionary.
code – Tensor of shape (n_samples, n_components) Sparse coding of the data against which to optimize the dictionary.
positive – Whether to enforce positivity when finding the dictionary.
dead_thresh – Minimum norm below which a dictionary atom is considered degenerate.
rng – torch.Generator used for random re-initialization of degenerate dictionary atoms.
- Returns:
torch.Tensor, torch.Tensor, corresponding to the updated dictionary (weight) and the updated code.
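A minimal sketch of one block coordinate descent pass (a hypothetical helper illustrating the technique, not the library's exact implementation; the library also updates the code in place, which this sketch omits). Each atom is re-fit against the reconstruction residual, optionally clamped to be nonnegative, re-initialized if degenerate, and renormalized:

```python
import torch

def bcd_dict_update_sketch(dictionary, x, code, positive=True,
                           dead_thresh=1e-7, rng=None):
    """One BCD pass; updates `dictionary` in place.

    dictionary: (n_features, n_components), x: (n_samples, n_features),
    code: (n_samples, n_components).
    """
    n_features, n_components = dictionary.shape
    # Residual of the current reconstruction x_hat = code @ dictionary.T
    residual = x - code @ dictionary.T
    for k in range(n_components):
        # Add atom k's contribution back so it can be re-fit in isolation.
        residual += torch.outer(code[:, k], dictionary[:, k])
        new_atom = residual.T @ code[:, k]          # least-squares direction
        if positive:
            new_atom = new_atom.clamp_min(0)        # positivity constraint
        if new_atom.norm() < dead_thresh:
            # Degenerate atom: re-initialize randomly.
            new_atom = torch.rand(n_features, generator=rng).to(dictionary.device)
        dictionary[:, k] = new_atom / new_atom.norm()
        # Remove the updated atom's contribution again.
        residual -= torch.outer(code[:, k], dictionary[:, k])
    return dictionary, code
```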
- torch_staintools.functional.optimization.dict_learning.update_dict_ridge(x: Tensor, code: Tensor, lambd: float) Tuple[Tensor, Tensor]
Update an (unconstrained) dictionary with ridge regression.
This is equivalent to a Newton step on the (L2-regularized) squared-error objective f(V) = (1/2N) * ||Vz - x||_2^2 + (lambd/2) * ||V||_2^2. May have severe numerical stability issues compared to update_dict_cd. The closed-form solution is sketched after the parameter list.
- Parameters:
x – a batch of observations with shape (n_samples, n_features)
code – a batch of code vectors with shape (n_samples, n_components)
lambd – weight decay parameter
- Returns:
torch.Tensor, torch.Tensor, corresponding to the weight and the unmodified code.
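Setting the gradient of f(V) to zero gives the closed form V = xᵀz (zᵀz + N·lambd·I)⁻¹. A minimal sketch under the documented shapes (not the library's exact implementation):

```python
import torch

def ridge_dict_update_sketch(x, code, lambd):
    """Closed-form minimizer of
    f(V) = (1/2N)||code @ V.T - x||^2 + (lambd/2)||V||^2,
    with x: (n_samples, n_features), code: (n_samples, n_components)."""
    n_samples, n_components = code.shape
    gram = code.T @ code + n_samples * lambd * torch.eye(
        n_components, device=x.device, dtype=x.dtype)
    # V = x^T z (z^T z + N*lambd*I)^{-1}; solve rather than invert.
    weight = torch.linalg.solve(gram, code.T @ x).T   # (n_features, n_components)
    return weight, code
```

The ill-conditioning of gram when the code columns are nearly collinear is the numerical-stability caveat noted above.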
torch_staintools.functional.optimization.solver module
Code directly adapted from https://github.com/rfeinman/pytorch-lasso.
- torch_staintools.functional.optimization.solver.coord_descent(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, maxiter: int, tol: float, positive_code: bool)
Modified coordinate descent solver, adapted from pytorch-lasso's coord_descent.
- torch_staintools.functional.optimization.solver.fista(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, lr: Tensor, maxiter: int, tol: float, positive_code)
FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) solver. A minimal sketch follows the parameter list.
- Parameters:
x – data (OD space).
z0 – initial code, or the initialization mode of the code.
weight – dictionary (stain matrix).
alpha – L1 penalty weight on the code.
lr – learning rate/step size. If 'auto', it is derived from the Lipschitz constant of f(z) = ||Wz - x||^2.
maxiter – maximum number of iterations if the solver does not converge.
tol – tolerance of the convergence test.
positive_code – whether to enforce the nonnegativity constraint on z.
- Returns:
the optimized sparse code.
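A minimal FISTA sketch for the batched layout used here (z: num_pixels x num_stains, reconstruction z @ weight.T; names are illustrative): a gradient step on the smooth term at the extrapolated point, a soft-thresholding prox for the L1 term, and a Nesterov momentum update. The library version additionally precomputes the Hessian (see fista_loop below) and handles the 'auto' learning rate.

```python
import torch

def fista_sketch(x, z0, weight, alpha, lr, maxiter=100, tol=1e-6,
                 positive_code=True):
    """Minimize ||z @ weight.T - x||^2 + alpha * ||z||_1 via FISTA."""
    z, y, t = z0.clone(), z0.clone(), 1.0
    for _ in range(maxiter):
        # Gradient of the smooth part at the extrapolated point y.
        grad = 2 * (y @ weight.T - x) @ weight
        z_next = y - lr * grad
        # Prox of the L1 term: soft-thresholding.
        z_next = torch.sign(z_next) * torch.clamp(z_next.abs() - lr * alpha, min=0)
        if positive_code:
            z_next = z_next.clamp_min(0)
        # Nesterov momentum update.
        t_next = (1.0 + (1.0 + 4.0 * t ** 2) ** 0.5) / 2.0
        y = z_next + ((t - 1.0) / t_next) * (z_next - z)
        converged = (z_next - z).abs().max() < tol
        z, t = z_next, t_next
        if converged:
            break
    return z
```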
- torch_staintools.functional.optimization.solver.fista_loop(z: Tensor, hessian: Tensor, b: Tensor, alpha: Tensor, lr: Tensor, tol: float, maxiter: int, positive_code: bool = True) Tensor
FISTA Loop
- Parameters:
z – Initial guess
hessian – precomputed weight.T @ weight, where weight is the dictionary matrix
b – precomputed x @ weight, where x is the data input (OD space)
alpha – Regularization strength
lr – Learning rate
maxiter – Maximum iterations
tol – Convergence tolerance
positive_code – whether to enforce the nonnegativity constraint on z
- torch_staintools.functional.optimization.solver.fista_step(z: Tensor, y: Tensor, t: Tensor, hessian: Tensor, b: Tensor, alpha: Tensor, lr: Tensor, positive_code: bool, tol: float)
- torch_staintools.functional.optimization.solver.ista(x: Tensor, z0: Tensor, weight: Tensor, alpha: Tensor, lr: Tensor, maxiter: int, tol: float, positive_code: bool)
ISTA (Iterative Shrinkage-Thresholding Algorithm) solver. It differs from fista above only in omitting the momentum step; see the FISTA sketch there.
- Parameters:
x – data (OD space).
z0 – initial code, or the initialization mode of the code.
weight – dictionary (stain matrix).
alpha – L1 penalty weight on the code.
lr – learning rate/step size. If 'auto', it is derived from the Lipschitz constant of f(z) = ||Wz - x||^2.
maxiter – maximum number of iterations if the solver does not converge.
tol – tolerance of the convergence test.
positive_code – whether to enforce the nonnegativity constraint on z.
- Returns:
the optimized sparse code.
- torch_staintools.functional.optimization.solver.ista_loop(z: Tensor, hessian: Tensor, b: Tensor, alpha: Tensor, lr: Tensor, tol: float, maxiter: int, positive_code: bool)
- torch_staintools.functional.optimization.solver.ista_step(z: Tensor, hessian: Tensor, b: Tensor, alpha: Tensor, lr: Tensor, positive: bool) Tensor
- Parameters:
z – code, shape num_pixels x num_stains.
hessian – precomputed weight.T @ weight, where weight is the dictionary initialized from the stain matrix (num_channels x num_stains).
b – precomputed x @ weight, where x is the data in OD space (num_pixels x num_channels).
alpha – the L1 penalty of ISTA, as a tensor.
lr – step size, as a tensor.
positive – whether to force z to be nonnegative.
- Returns:
the updated code z after one ISTA step (see the sketch below).
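With hessian = weight.T @ weight and b = x @ weight precomputed, the gradient 2 * (z @ weight.T - x) @ weight collapses to 2 * (z @ hessian - b), so neither x nor weight is touched inside the loop. A sketch of one step (the library may fold constant factors into lr and alpha):

```python
import torch

def ista_step_sketch(z, hessian, b, alpha, lr, positive=True):
    """One ISTA step with precomputed hessian = W.T @ W and b = x @ W."""
    z_next = z - lr * 2.0 * (z @ hessian - b)   # gradient step
    # Soft-thresholding: prox of the L1 penalty.
    z_next = torch.sign(z_next) * torch.clamp(z_next.abs() - lr * alpha, min=0)
    return z_next.clamp_min(0) if positive else z_next
```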
- torch_staintools.functional.optimization.solver.rss_grad(z_k: Tensor, x: Tensor, weight: Tensor)
- torch_staintools.functional.optimization.solver.rss_grad_fast(z_k: Tensor, hessian: Tensor, b: Tensor)
- torch_staintools.functional.optimization.solver.softshrink(x: Tensor, lambd: Tensor) Tensor
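softshrink is the soft-thresholding operator, softshrink(x, lambd) = sign(x) * max(|x| - lambd, 0), i.e., the proximal operator of the L1 penalty. A custom version is presumably needed because torch.nn.functional.softshrink only accepts a Python float lambd, not a tensor. A minimal equivalent:

```python
import torch

def softshrink_sketch(x: torch.Tensor, lambd: torch.Tensor) -> torch.Tensor:
    # sign(x) * max(|x| - lambd, 0), elementwise, with broadcastable lambd.
    return torch.sign(x) * torch.clamp(x.abs() - lambd, min=0)
```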
torch_staintools.functional.optimization.sparse_util module
- torch_staintools.functional.optimization.sparse_util.as_scalar(v: float | Tensor, like: Tensor) Tensor
- torch_staintools.functional.optimization.sparse_util.collate_params(z0: Tensor, x: Tensor, lr: float | Tensor | None, weight: Tensor, alpha: float | Tensor, tol: float) Tuple[Tensor, Tensor, float]
- torch_staintools.functional.optimization.sparse_util.initialize_code(x: Tensor, weight: Tensor, mode: Literal['zero', 'transpose', 'unif', 'ridge'], rng: Generator)
Code initialization in dictionary learning.
Dictionary learning finds the sparse decomposition of the data, X = D * A, where D is the dictionary and A is the code. For ridge initialization, the L2 penalty is set by the constant constants.INIT_RIDGE_L2. A sketch of the four modes follows.
- Parameters:
x – data
weight – dictionary
mode – code initialization method
rng – torch.Generator for random initialization modes
- Returns:
the initialized code z0.
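A sketch of what the four modes plausibly do, inferred from the mode names and the pytorch-lasso code this module is adapted from (the exact ridge penalty lives in constants.INIT_RIDGE_L2; ridge_l2 below, its value, and the uniform range are illustrative stand-ins):

```python
import torch

def initialize_code_sketch(x, weight, mode, rng=None, ridge_l2=1e-4):
    """x: (n_samples, n_features), weight: (n_features, n_components)."""
    n_samples, n_components = x.shape[0], weight.shape[1]
    if mode == 'zero':
        return x.new_zeros(n_samples, n_components)
    if mode == 'transpose':
        return x @ weight                       # W^T x in this batched layout
    if mode == 'unif':
        z0 = torch.empty(n_samples, n_components, device=x.device, dtype=x.dtype)
        return z0.uniform_(-0.1, 0.1, generator=rng)  # range is illustrative
    if mode == 'ridge':
        # (W^T W + l2 * I)^{-1} W^T x, batched.
        gram = weight.T @ weight + ridge_l2 * torch.eye(
            n_components, device=x.device, dtype=x.dtype)
        return torch.linalg.solve(gram, (x @ weight).T).T
    raise ValueError(f"unknown init mode: {mode}")
```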
- torch_staintools.functional.optimization.sparse_util.initialize_dict(n_features: int, n_components: int, device: device | str, rng: Generator, positive_dict: bool)
- torch_staintools.functional.optimization.sparse_util.lipschitz_constant(w: Tensor)
Find the Lipschitz constant used to set the learning rate in ISTA.
- Parameters:
w – the weight matrix W in f(z) = ||Wz - x||^2
- Returns:
the Lipschitz constant of the gradient of f (see the sketch below).
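For f(z) = ||Wz - x||^2, the gradient ∇f(z) = 2 Wᵀ(Wz - x) is Lipschitz with constant L = 2 σ_max(W)², twice the largest eigenvalue of WᵀW; ISTA/FISTA then use a step size of 1/L. A sketch (the library may scale the constant differently, depending on how it normalizes the objective):

```python
import torch

def lipschitz_constant_sketch(w: torch.Tensor) -> torch.Tensor:
    # Spectral norm (largest singular value) of W; the gradient
    # 2 W^T (Wz - x) is Lipschitz with constant 2 * sigma_max(W)^2.
    sigma_max = torch.linalg.matrix_norm(w, ord=2)
    return 2.0 * sigma_max ** 2
```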
- torch_staintools.functional.optimization.sparse_util.ridge(b: Tensor, a: Tensor, alpha: float | None = None)
- torch_staintools.functional.optimization.sparse_util.validate_code(algorithm: Literal['ista', 'cd', 'fista'], init: Literal['zero', 'transpose', 'unif', 'ridge'] | None, z0: Tensor | None, x: Tensor, weight: Tensor, rng)