torch_staintools.base_module package
Submodules
torch_staintools.base_module.base module
- class torch_staintools.base_module.base.CachedRNGModule(cache: TensorCache | None, device: device | None, rng: int | Generator | None)
Bases:
Module
Optionally cache the stain matrices and manage the rng
Note that using nn.Module.to(device) to move the module across GPU/cpu device will reset the states.
- CACHE_FIELD: str = '_tensor_cache'
- static _init_cache(use_cache: bool, cache_size_limit: int, device: device | None = None, load_path: str | None = None) TensorCache | None
- static _rng_to(rng: Generator | None, device: device)
- _tensor_cache: TensorCache
- _tensor_cache_helper() TensorCache | None
- abstract classmethod build(*args, **kwargs)
- cache_initialized()
- property cache_size_limit: int
- device: device
- dump_cache(path: str)
- rng: Generator | None
- property tensor_cache: TensorCache | None
- tensor_from_cache(*, cache_keys: List[Hashable] | None, func_partial: Callable, target) Tensor
- static tensor_from_cache_helper(cache: TensorCache, *, cache_keys: List[Hashable], func_partial: Callable, target) Tensor
- to(device: device)
Moves and/or casts the parameters and buffers.
This can be called as
- to(device=None, dtype=None, non_blocking=False)
- to(dtype, non_blocking=False)
- to(tensor, non_blocking=False)
- to(memory_format=torch.channels_last)
Its signature is similar to
torch.Tensor.to()
, but only accepts floating point or complexdtype
s. In addition, this method will only cast the floating point or complex parameters and buffers todtype
(if given). The integral parameters and buffers will be moveddevice
, if that is given, but with dtypes unchanged. Whennon_blocking
is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.See below for examples.
Note
This method modifies the module in-place.
- Parameters:
device (
torch.device
) – the desired device of the parameters and buffers in this moduledtype (
torch.dtype
) – the desired floating point or complex dtype of the parameters and buffers in this moduletensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module
memory_format (
torch.memory_format
) – the desired memory format for 4D parameters and buffers in this module (keyword only argument)
- Returns:
self
- Return type:
Module
Examples:
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device("cuda:1") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device("cpu") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)