torch_staintools.base_module package
Submodules
torch_staintools.base_module.base module
- class torch_staintools.base_module.base.CachedRNGModule(cache: TensorCache | None, device: device | None, rng: int | Generator | None)
- Bases: Module

  Optionally caches the stain matrices and manages the rng.

  Note that using nn.Module.to(device) to move the module across GPU/CPU devices will reset these states; a usage sketch illustrating this follows at the end of this entry.

  - CACHE_FIELD: str = '_tensor_cache'
 - static _init_cache(use_cache: bool, cache_size_limit: int, device: device | None = None, load_path: str | None = None) → TensorCache | None
 - static _rng_to(rng: Generator | None, device: device)
 - _tensor_cache: TensorCache
 - _tensor_cache_helper() → TensorCache | None
 - abstract classmethod build(*args, **kwargs)
 - cache_initialized()
 - property cache_size_limit: int
 - device: device
 - dump_cache(path: str)
 - rng: Generator | None
 - property tensor_cache: TensorCache | None
 - tensor_from_cache(*, cache_keys: List[Hashable] | None, func_partial: Callable, target) → Tensor
 - static tensor_from_cache_helper(cache: TensorCache, *, cache_keys: List[Hashable], func_partial: Callable, target) → Tensor
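   The tensor_from_cache helpers follow a memoization pattern: look the result up under cache_keys and, on a miss, compute it by applying func_partial to target, storing the result for reuse. The snippet below is a minimal sketch of that pattern only, not the library's implementation: the plain dict standing in for TensorCache, the composite-key layout, and the placeholder estimator are assumptions for illustration.

       from functools import partial
       from typing import Callable, Hashable, List

       import torch
       from torch import Tensor

       def fetch_or_compute(cache: dict, cache_keys: List[Hashable],
                            func_partial: Callable[[Tensor], Tensor],
                            target: Tensor) -> Tensor:
           # Assumed contract: treat the key list as one composite key and
           # cache a single tensor per key; compute via func_partial on a miss.
           key = tuple(cache_keys)
           if key not in cache:
               cache[key] = func_partial(target)
           return cache[key]

       cache: dict = {}
       image = torch.rand(1, 3, 64, 64)  # BCHW batch, values in [0, 1]
       # Placeholder per-image estimator; a real stain-matrix estimator would
       # be passed in as a partial with its own hyperparameters bound.
       estimate = partial(torch.mean, dim=(-2, -1))
       matrix = fetch_or_compute(cache, ["slide_001.png"], estimate, image)
       matrix_again = fetch_or_compute(cache, ["slide_001.png"], estimate, image)
       assert matrix is matrix_again  # second call is a cache hit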
 - to(device: device)
- Moves and/or casts the parameters and buffers.

  This can be called as

  - to(device=None, dtype=None, non_blocking=False)
  - to(dtype, non_blocking=False)
  - to(tensor, non_blocking=False)
  - to(memory_format=torch.channels_last)

  Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

  See below for examples.

  Note: This method modifies the module in-place.

  Parameters:
- device (torch.device) – the desired device of the parameters and buffers in this module
- dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module
- tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module
- memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument)
 
- Returns:
- self 
- Return type:
- Module 
- Examples:

      >>> # xdoctest: +IGNORE_WANT("non-deterministic")
      >>> linear = nn.Linear(2, 2)
      >>> linear.weight
      Parameter containing:
      tensor([[ 0.1913, -0.3420],
              [-0.5113, -0.2325]])
      >>> linear.to(torch.double)
      Linear(in_features=2, out_features=2, bias=True)
      >>> linear.weight
      Parameter containing:
      tensor([[ 0.1913, -0.3420],
              [-0.5113, -0.2325]], dtype=torch.float64)

      >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
      >>> gpu1 = torch.device("cuda:1")
      >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
      Linear(in_features=2, out_features=2, bias=True)
      >>> linear.weight
      Parameter containing:
      tensor([[ 0.1914, -0.3420],
              [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
      >>> cpu = torch.device("cpu")
      >>> linear.to(cpu)
      Linear(in_features=2, out_features=2, bias=True)
      >>> linear.weight
      Parameter containing:
      tensor([[ 0.1914, -0.3420],
              [-0.5112, -0.2324]], dtype=torch.float16)

      >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
      >>> linear.weight
      Parameter containing:
      tensor([[ 0.3741+0.j,  0.2382+0.j],
              [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
      >>> linear(torch.ones(3, 2, dtype=torch.cdouble))
      tensor([[0.6122+0.j, 0.1150+0.j],
              [0.6122+0.j, 0.1150+0.j],
              [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
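- Usage sketch (cache persistence across device moves): as the note at the top of this entry says, moving a CachedRNGModule across devices resets its cached state, so a cache worth keeping should be written out with dump_cache first and restored later (e.g., via the load_path plumbing of _init_cache). The sketch below assumes only the constructor and methods documented above; IdentityNormalizer is a hypothetical minimal subclass, since CachedRNGModule itself is abstract, and the real concrete normalizers live elsewhere in torch_staintools.

      import torch
      from torch_staintools.base_module.base import CachedRNGModule

      class IdentityNormalizer(CachedRNGModule):
          """Hypothetical concrete subclass: no-op forward, for illustration."""

          @classmethod
          def build(cls, rng=None):
              # Constructor signature per the class entry above:
              # CachedRNGModule(cache, device, rng). Build without a cache.
              return cls(cache=None, device=torch.device("cpu"), rng=rng)

          def forward(self, x):
              return x

      norm = IdentityNormalizer.build(rng=314)  # int seed, per the rng annotation
      if norm.cache_initialized():
          norm.dump_cache("stain_cache.pt")  # persist before any device move
      # Per the class note, to(device) resets the cached state and the managed rng.
      if torch.cuda.is_available():
          norm = norm.to(torch.device("cuda:0"))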