torch_staintools.base_module package

Submodules

torch_staintools.base_module.base module

class torch_staintools.base_module.base.CachedRNGModule(cache: TensorCache | None, device: device | None, rng: int | Generator | None)

Bases: Module

Optionally caches the stain matrices and manages the RNG state.

Note that using nn.Module.to(device) to move the module between GPU/CPU devices will reset its internal state.
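Because to(device) resets the state, a populated cache can be persisted before a device move and restored afterwards using the dump_cache/load_cache methods listed below. A minimal sketch, assuming a hypothetical concrete subclass ConcreteNormalizer whose build() arguments are illustrative only:

import torch

# Hypothetical concrete subclass of CachedRNGModule; the build()
# arguments shown here are illustrative, not the library's API.
normalizer = ConcreteNormalizer.build(rng=42)

# ... run the normalizer so that stain matrices get cached ...

if normalizer.cache_initialized():
    normalizer.dump_cache('stain_cache.pt')  # persist the cache

normalizer = normalizer.to(torch.device('cuda:0'))  # resets the state

normalizer.load_cache('stain_cache.pt')  # restore the cached entries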

CACHE_FIELD: str = '_tensor_cache'
abstract classmethod build(*args, **kwargs)
cache_initialized()
property cache_size_limit: int
default_hash(cache_keys: List[Hashable] | None, target: Tensor, mask: Tensor) → List[Hashable] | None
device: device
dump_cache(path: str)
get_batch(keys: List[Hashable], get_stain_mat: StainExtraction, target: Tensor, mask: Tensor) → List[Tensor] | Tensor

Batched get.

The method assumes that the get_stain_mat callable generates a whole batch of data per call. This can be useful when batched processing is much faster than processing inputs one by one (e.g., CUDA tensors processed by an nn.Module).

A cache hit occurs only if all keys are cached; see the usage sketch after the parameter list.

Parameters:
  • keys – list of keys corresponding to the batch input.

  • get_stain_mat – function to generate the data if the corresponding entry is not cached.

  • target – target tensor, potentially in optical density (OD) space.

  • mask – binary mask; background pixels are 0 and foreground regions are 1.

Returns:

List of queried results.
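A minimal usage sketch, assuming the hypothetical normalizer instance from above and an existing StainExtraction callable get_stain_mat (its exact signature is defined by the library):

import torch

# Hypothetical batch of four RGB tiles (BCHW) and a matching
# foreground mask (1 = tissue, 0 = background).
target = torch.rand(4, 3, 256, 256)
mask = torch.ones(4, 1, 256, 256)

# One key per image in the batch, e.g. the file names.
keys = [f'tile_{i}.png' for i in range(4)]

# get_stain_mat is invoked on the whole batch whenever not all keys
# are already cached; on a full hit the cached results are returned.
stain_mats = normalizer.get_batch(keys, get_stain_mat, target, mask)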

load_cache(path: str)
property rng
stain_mat_cached(*, cache_keys: List[Hashable] | None, get_stain_mat: StainExtraction, target: Tensor, mask: Tensor) → Tensor
stain_mat_cached_helper(cache_keys: List[Hashable], get_stain_mat: StainExtraction, target: Tensor, mask: Tensor) → Tensor
property tensor_cache: TensorCache | None
to(device: device)

Move and/or cast the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)
to(memory_format=torch.channels_last)

Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See below for examples.

Note

This method modifies the module in-place.

Parameters:
  • device (torch.device) – the desired device of the parameters and buffers in this module

  • dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module

  • tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module

  • memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument)

Returns:

self

Return type:

Module

Examples:

>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)

Module contents