torch_staintools.normalizer package

Submodules

torch_staintools.normalizer.base module

class torch_staintools.normalizer.base.DataInput

Bases: TypedDict

Reserved for future compatibility, e.g., a moving average of stain matrices from the same WSI, which needs a uri to identify the slide.

img: ndarray | Tensor | Image
uri: str
class torch_staintools.normalizer.base.Normalizer(cache: TensorCache | None, device: device | None, rng: int | Generator | None)

Bases: CachedRNGModule

Generic normalizer interface with fit/transform and a forward call that at least calls transform.

Note that inputs are always expected to be PyTorch tensors in BCHW convention.

classmethod build(*args, **kwargs) Normalizer
abstract fit(*args, **kwargs)
abstract forward(x: DataInput | Tensor, *args, **kwargs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
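
The difference between calling the module instance and calling forward directly can be seen with a forward hook. This is a minimal sketch on a toy nn.Module; the PassThrough class and the calls list are illustrative names, not part of torch_staintools.

```python
import torch
from torch import nn


class PassThrough(nn.Module):
    """Toy module standing in for a concrete Normalizer subclass (hypothetical)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


calls = []
module = PassThrough()
# A forward hook receives (module, inputs, output) after forward has run.
module.register_forward_hook(lambda mod, inputs, output: calls.append("hook"))

x = torch.zeros(1, 3, 4, 4)  # BCHW
module(x)          # goes through __call__, so the hook fires
module.forward(x)  # bypasses __call__, so the hook is skipped
```

After both calls, the hook has been recorded only once, for the call made through the module instance.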

abstract transform(x, *args, **kwarags)

torch_staintools.normalizer.factory module

class torch_staintools.normalizer.factory.NormalizerBuilder

Bases: object

Factory Builder for all supported normalizers: reinhard, macenko, and vahadane

static build(method: Literal['reinhard', 'vahadane', 'macenko'], concentration_method: Literal['ista', 'cd', 'ls'] = 'ista', num_stains: int = 2, luminosity_threshold: float = 0.8, regularizer: float = 0.1, rng: int | Generator | None = None, use_cache: bool = False, cache_size_limit: int = -1, device: device | None = None, load_path: str | None = None) Normalizer

Build a normalizer from the algorithm name given in method.

Warning

concentration_method = ‘ls’ may fail on GPU for an individual large input (e.g., 1000 x 1000), regardless of batch size. Therefore, ‘ls’ is better suited to multiple inputs that are small in H and W.

Parameters:
  • method – Name of the stain normalization algorithm. Supports reinhard, macenko, and vahadane.

  • concentration_method – method to obtain the concentration. Only applies to StainSeparation-based approaches (macenko and vahadane). Supports ‘ista’, ‘cd’, and ‘ls’. The default, ‘ista’, gives a fast sparse solution on GPU. ‘ls’ simply solves the least squares problem min||HExC - OD|| and is faster; ‘ista’ and ‘cd’ enforce the sparsity penalty but are slower.

  • num_stains – number of stains to separate. Currently, Macenko only supports 2. Only applies to macenko and ‘vahadane’ methods.

  • luminosity_threshold – luminosity threshold to ignore the background. None means all regions are considered as tissue. Scale of luminosity threshold is within [0, 1]. Only applies to macenko and ‘vahadane’ methods.

  • regularizer – regularizer term in ISTA for stain separation and concentration computation. Only applies to macenko and ‘vahadane’ methods if ‘ista’ is used.

  • rng – seed or torch.Generator for any random initialization that may occur.

  • use_cache – whether to use cache to save the stain matrix of input image to normalize. Only applies to macenko and ‘vahadane’

  • cache_size_limit – size limit of the cache. negative means no limits. Only applies to macenko and ‘vahadane’

  • device – the device that holds the cache and the normalizer. If None, the device is set to cpu. Only applies to macenko and ‘vahadane’

  • load_path – If specified, then stain matrix cache will be loaded from the file path. See the cache module for more details. Only applies to macenko and ‘vahadane’

Returns:

A Normalizer built for the selected method.
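
A minimal usage sketch, built only from the signatures listed on this page. The wrapper function normalize_batch and its argument names are hypothetical, and the import is deferred so the snippet can be defined without the package installed. It assumes target and batch are BCHW torch.float32 tensors scaled to [0, 1].

```python
import torch


def normalize_batch(target: torch.Tensor, batch: torch.Tensor,
                    keys=None, method: str = "macenko") -> torch.Tensor:
    """Fit a normalizer to `target` and apply it to `batch` (hypothetical helper)."""
    # Deferred import: module path taken from this page.
    from torch_staintools.normalizer.factory import NormalizerBuilder

    normalizer = NormalizerBuilder.build(
        method,
        concentration_method="ista",
        use_cache=keys is not None,  # cache stain matrices only when keys are given
        rng=0,                       # fixed seed for reproducible initialization
        device=target.device,
    )
    normalizer.fit(target)
    # Call the module instance rather than forward() so that hooks run.
    return normalizer(batch, cache_keys=keys)
```

Passing one hashable key per image in keys lets a cached stain matrix be reused when the same image is normalized again.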

torch_staintools.normalizer.reinhard module

class torch_staintools.normalizer.reinhard.ReinhardNormalizer(luminosity_threshold: float | None)

Bases: Normalizer

Very simple Reinhard normalizer.

static _mean_std_helper(image: Tensor, *, mask: Tensor | None = None) Tuple[Tensor, Tensor]

Get the channel-wise mean and std of input

Parameters:
  • image – BCHW scaled to [0, 1] torch.float32. Usually in LAB.

  • mask – luminosity tissue mask of image. Mean and std are only computed within the tissue regions.

Returns:

means, stds

classmethod build(luminosity_threshold: float | None = None, **kwargs)
fit(image: Tensor)

Fit: compute the means and stds of the template in LAB space.

Statistics are computed within tissue regions if a luminosity threshold is given to the normalizer upon creation.

Parameters:

image – template. BCHW. [0, 1] torch.float32.

Returns:

forward(x: Tensor, *args, **kwargs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

luminosity_threshold: float
static normalize_helper(image: Tensor, target_means: Tensor, target_stds: Tensor, mask: Tensor | None = None)

Helper that shifts and scales image to the target channel-wise statistics.

Parameters:
  • image – BCHW format. torch.float32 type in range [0, 1].

  • target_means – channel-wise means of template

  • target_stds – channel-wise stds of template

  • mask – Optional luminosity tissue mask to compute the stats within masked region

Returns:

target_means: Tensor
target_stds: Tensor
transform(x: Tensor, *args, **kwargs)

Normalize by (input-mean_input) * (target_std/input_std) + target_mean

Performed in LAB space. Output is converted back to RGB.

Parameters:
  • x – input tensor

  • *args – for compatibility of interface.

  • **kwargs – for compatibility of interface.

Returns:

output torch.float32 RGB in range [0, 1] and shape BCHW
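
The formula above can be sketched on a generic BCHW tensor. The sketch omits the RGB/LAB conversion and the tissue mask, and the function name match_channel_stats and the eps guard are assumptions rather than the library's code.

```python
import torch


def match_channel_stats(x, target_means, target_stds, eps: float = 1e-8):
    """(input - mean_input) * (target_std / input_std) + target_mean, per channel."""
    mean_input = x.mean(dim=(2, 3), keepdim=True)
    std_input = x.std(dim=(2, 3), unbiased=False, keepdim=True)
    # eps guards against division by zero on constant channels.
    return (x - mean_input) * (target_stds / (std_input + eps)) + target_means


torch.manual_seed(0)
x = torch.rand(2, 3, 8, 8)
target_means = torch.tensor([0.6, 0.5, 0.4]).view(1, 3, 1, 1)
target_stds = torch.tensor([0.10, 0.20, 0.05]).view(1, 3, 1, 1)
out = match_channel_stats(x, target_means, target_stds)
```

After the transform, each channel of out has approximately the target mean and standard deviation.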

torch_staintools.normalizer.separation module

Note that some of the code is derived from torchvahadane and staintools.

class torch_staintools.normalizer.separation.StainSeparation(get_stain_matrix: BaseExtractor, concentration_method: Literal['ista', 'cd', 'ls'] = 'ista', num_stains: int = 2, luminosity_threshold: float = 0.8, regularizer: float = 0.1, rng: int | Generator | None = None, cache: TensorCache | None = None, device: device | None = None)

Bases: Normalizer

Stain Separation-based normalizer’s interface: Macenko and Vahadane

The stain matrix of the reference image (i.e., the target image) is included in the state_dict when torch.save() is used to export the normalizer’s state dict.
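
How a registered buffer travels with the state dict can be sketched with a toy module. TinyNormalizer is hypothetical and only borrows the attribute name stain_matrix_target from this page; the 2 x 3 shape and the matrix values are illustrative.

```python
import io

import torch
from torch import nn


class TinyNormalizer(nn.Module):
    """Toy stand-in: holds a stain matrix as a buffer (hypothetical class)."""

    def __init__(self):
        super().__init__()
        # Buffers are saved in state_dict and follow .to(device), but are not trained.
        self.register_buffer("stain_matrix_target", torch.zeros(2, 3))

    def fit(self, stain_matrix: torch.Tensor) -> None:
        self.stain_matrix_target.copy_(stain_matrix)


fitted = TinyNormalizer()
fitted.fit(torch.tensor([[0.65, 0.70, 0.29], [0.07, 0.99, 0.11]]))

# Round-trip the state dict through an in-memory file.
stream = io.BytesIO()
torch.save(fitted.state_dict(), stream)
stream.seek(0)
state = torch.load(stream)

restored = TinyNormalizer()
restored.load_state_dict(state)
```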

Warning

concentration_method = ‘ls’ may fail on GPU for an individual large input (e.g., 1000 x 1000), regardless of batch size. Therefore, ‘ls’ is better suited to multiple inputs that are small in H and W.

classmethod build(method: str, concentration_method: Literal['ista', 'cd', 'ls'] = 'ista', num_stains: int = 2, luminosity_threshold: float = 0.8, regularizer: float = 0.1, rng: int | Generator | None = None, use_cache: bool = False, cache_size_limit: int = -1, device: device | None = None, load_path: str | None = None) StainSeparation

Builder.

Parameters:
  • method – name of the stain extractor: vahadane or macenko

  • concentration_method – method to obtain the concentration. Defaults to ‘ista’ for computational efficiency on GPU. Supports ‘ista’, ‘cd’, and ‘ls’. ‘ls’ simply solves the least squares problem min||HExC - OD|| and is faster; ‘ista’ and ‘cd’ enforce the sparsity penalty but are slower.

  • num_stains – number of stains to separate. Currently, Macenko only supports 2. In general cases it is recommended to set num_stains as 2.

  • luminosity_threshold – luminosity threshold to ignore the background. None means all regions are considered as tissue.

  • regularizer – regularizer term in ista for stain separation and concentration computation.

  • rng – seed or torch.Generator for any random initialization that may occur.

  • use_cache – whether to use cache to save the stain matrix of input image to normalize

  • cache_size_limit – size limit of the cache. negative means no limits.

  • device – the device that holds the cache and the normalizer. If None, the device is set to cpu.

  • load_path – If specified, then stain matrix cache will be loaded from the file path. See the cache module for more details.

Returns:

StainSeparation normalizer.
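
The trade-off between ‘ls’ and ‘ista’ can be sketched in plain PyTorch. This is a textbook illustration, not the library's solver. The shape convention (HE as 3 x num_stains, C as num_stains x pixels, OD as 3 x pixels), the function names, and the stain matrix values are all assumptions.

```python
import torch


def solve_ls(he: torch.Tensor, od: torch.Tensor) -> torch.Tensor:
    """Plain least squares: argmin_C ||HE @ C - OD||."""
    return torch.linalg.lstsq(he, od).solution


def solve_ista(he, od, lam: float = 0.1, n_iter: int = 300) -> torch.Tensor:
    """ISTA for argmin_C 0.5 * ||HE @ C - OD||^2 + lam * ||C||_1."""
    # Step size 1/L, where L is the squared spectral norm of HE.
    step = 1.0 / torch.linalg.matrix_norm(he, ord=2) ** 2
    c = torch.zeros(he.shape[1], od.shape[1], dtype=he.dtype)
    for _ in range(n_iter):
        z = c - step * (he.T @ (he @ c - od))  # gradient step
        c = torch.sign(z) * torch.clamp(z.abs() - step * lam, min=0.0)  # soft threshold
    return c


# Illustrative 3 x 2 stain matrix (columns: two stains) and synthetic concentrations.
he = torch.tensor([[0.65, 0.07],
                   [0.70, 0.99],
                   [0.29, 0.11]])
torch.manual_seed(0)
c_true = torch.rand(2, 16)
od = he @ c_true

c_ls = solve_ls(he, od)
c_ista = solve_ista(he, od, lam=0.0)    # no penalty: should agree with least squares
c_sparse = solve_ista(he, od, lam=0.1)  # L1 penalty shrinks the concentrations
```

The least squares solve is a single call, while ISTA iterates, which is the speed difference described above. The L1 penalty is what encourages sparse concentrations.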

concentration_method: Literal['ista', 'cd', 'ls']
fit(target, concentration_method: Literal['ista', 'cd', 'ls'] | None = None, **stainmat_kwargs)

Fit to a target image.

Note that the stain matrices are registered as buffers so that they are moved to the specified device along with the nn.Module object.

Parameters:
  • target – BCHW. Assume it’s cast to torch.float32 and scaled to [0, 1]

  • concentration_method – method to obtain concentration. Use the self.concentration_method if not specified in the signature.

  • **stainmat_kwargs – Extra keyword arguments of the stain separator, besides the num_stains/luminosity_threshold that are set in __init__

Returns:

forward(x: Tensor, cache_keys: List[Hashable] | None = None, **stain_mat_kwargs) Tensor
Parameters:
  • x – input batch image tensor in shape of BxCxHxW

  • cache_keys – unique keys that map the images of the input batch to their cached stain matrices. None means no cache.

  • **stain_mat_kwargs – Keyword arguments for stain matrix estimators other than those defined in __init__, i.e., luminosity_threshold, regularizer, and num_stains.

Returns:

Normalized output in BxCxHxW shape and float32 dtype. Some pixel values may exceed [0, 1], so a clipping operation is applied.

Return type:

torch.Tensor

get_stain_matrix: BaseExtractor
num_stains: int
regularizer: float
static repeat_stain_mat(stain_mat: Tensor, image: Tensor) Tensor

Helper function for vectorization and broadcasting

Parameters:
  • stain_mat – a (usually source) stain matrix obtained from fitting

  • image – input batch image

Returns:

repeated stain matrix
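
One way such a broadcasting helper could work is sketched below. It is not the library's implementation: the function name repeat_for_batch and the assumption that a single stain matrix has shape num_stains x 3 are illustrative.

```python
import torch


def repeat_for_batch(stain_mat: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
    """Expand one (num_stains x 3) matrix to (B x num_stains x 3) without copying."""
    batch_size = image.shape[0]
    return stain_mat.unsqueeze(0).expand(batch_size, -1, -1)


stain_mat = torch.tensor([[0.65, 0.70, 0.29],
                          [0.07, 0.99, 0.11]])
image = torch.rand(4, 3, 8, 8)  # BCHW batch
repeated = repeat_for_batch(stain_mat, image)
```

With one matrix per batch element, batched operations such as torch.bmm can apply the same stain matrix to every image.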

rng: Generator
stain_matrix_target: Tensor
target_concentrations: Tensor
transform(image: Tensor, cache_keys: List[Hashable] | None = None, **stain_mat_kwargs) Tensor

Transformation operation.

The stain matrix is extracted from the source image using the specified stain separator (dictionary learning or SVD). The target concentration is by default computed by dictionary learning for both macenko and vahadane, same as staintools. The concentration is then normalized and the image is reconstructed in OD space.

Parameters:
  • image – Image input must be BxCxHxW, cast to torch.float32, and rescaled to [0, 1]. See torchvision.transforms.convert_image_dtype.

  • cache_keys – unique keys that map the images of the input batch to their cached stain matrices. None means no cache.

  • **stain_mat_kwargs – Extra keyword arguments of the stain separator besides the num_stains and luminosity_threshold that were already set in __init__. For instance, in Macenko, an angular percentile argument “perc” may be selected to separate the angles of the OD vectors projected on the SVD plane and the positive x-axis.

Returns:

Normalized output in BxCxHxW shape and float32 dtype. Some pixel values may exceed [0, 1], so a clipping operation is applied.

Return type:

torch.Tensor
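
The OD round trip and the final clipping can be sketched as follows, assuming the common Beer-Lambert convention OD = -log(I) for intensities scaled to [0, 1]. The library's exact scaling may differ, and the function names rgb_to_od and od_to_rgb and the eps guard are assumptions.

```python
import torch


def rgb_to_od(image: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Optical density of an image scaled to [0, 1]; eps avoids log(0)."""
    return -torch.log(image.clamp_min(eps))


def od_to_rgb(od: torch.Tensor) -> torch.Tensor:
    """Invert the OD transform and clip to the valid [0, 1] range."""
    return torch.exp(-od).clamp(0.0, 1.0)


torch.manual_seed(0)
image = torch.rand(1, 3, 4, 4) * 0.9 + 0.1  # keep away from zero intensity
roundtrip = od_to_rgb(rgb_to_od(image))

# A negative OD would map to an intensity above 1, hence the clipping.
clipped = od_to_rgb(torch.tensor([-0.5, 0.0, 2.0]))
```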

Module contents