torch_staintools.cache package
Submodules
torch_staintools.cache.base module
- class torch_staintools.cache.base.Cache(size_limit: int)
Bases: ABC, Generic[C, V]
A simple abstraction of a cache.
- __size_limit: int
- abstract _dump_helper(path: str)
To implement: dump the cached data to the local file system.
- Parameters:
path – output filename
- abstract _new_cache()
To implement: create a new, empty cache container.
- abstract _write_to_cache_helper(key: Hashable, value: V)
Write the data (value) to the given address (key) in the cache.
- Parameters:
key – any hashable that identifies the address of the data in the cache
value – value of the data to cache
- abstract classmethod build(*args, **kwargs)
- data_cache: C
- dump(path: str, force_overwrite: bool = False)
Dump the cached data to the local file system.
- Parameters:
path – output filename
force_overwrite – whether to overwrite an existing file at path
- get(key: Hashable, func: Callable | None, *func_args, **func_kwargs)
Get the data cached under key.
If no data is cached under key yet, it is computed as func(*func_args, **func_kwargs), and the result is cached if the remaining size is sufficient.
- Parameters:
key – the address of the data in the cache
func – callable that computes the data if nothing is cached under key yet
*func_args – positional arguments of func
**func_kwargs – keyword arguments of func
- Returns:
The data stored under key, or the newly computed result on a miss.
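The lookup-or-compute flow described for get can be illustrated with a plain dict. This is a minimal standalone sketch, not the library's implementation: the class name `DictCacheSketch` is hypothetical, size is assumed to be counted in entries (as TensorCache.build documents for size_limit), and treating a miss with func=None as a KeyError is an assumption the source does not specify.

```python
from typing import Any, Callable, Dict, Hashable, Optional


class DictCacheSketch:
    """Hypothetical dict-backed cache illustrating the get() flow."""

    def __init__(self, size_limit: int = -1) -> None:
        self.size_limit = size_limit  # negative: no limit enforced
        self.data_cache: Dict[Hashable, Any] = {}

    def is_cached(self, key: Hashable) -> bool:
        return key in self.data_cache

    def _has_room(self, n_new: int) -> bool:
        # Assumption: size is counted as the number of cached entries.
        if self.size_limit < 0:
            return True
        return len(self.data_cache) + n_new <= self.size_limit

    def get(self, key: Hashable, func: Optional[Callable], *func_args, **func_kwargs) -> Any:
        if self.is_cached(key):
            return self.data_cache[key]  # hit: return the cached value
        if func is None:
            # Assumption: a miss without a callable is treated as an error here.
            raise KeyError(key)
        value = func(*func_args, **func_kwargs)  # miss: compute the value
        if self._has_room(1):
            self.data_cache[key] = value  # cache only while within the limit
        return value


cache = DictCacheSketch(size_limit=1)
first = cache.get("a", lambda x: x * 2, 21)  # computed and cached
again = cache.get("a", None)                 # served from the cache
extra = cache.get("b", lambda: 0)            # computed, not cached (limit reached)
```

Note that the computed value is returned even when it does not fit, so callers always get a result; only the caching step is skipped.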
- get_batch(keys: List[Hashable], func: Callable | None, *func_args, **func_kwargs) → List[V]
Batched version of get.
The method assumes that each call of func generates a whole batch of data. This can be useful when batched processing is much faster than processing each input individually (e.g., CUDA tensors processed by an nn.Module).
It is a cache hit only if all keys are cached.
- Parameters:
keys – list of keys corresponding to the batch input.
func – function to generate the data if the corresponding entry is not cached.
*func_args – positional args for the func.
**func_kwargs – keyword args for the func.
- Returns:
List of queried results.
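The all-or-nothing hit rule of get_batch can be sketched on a plain dict as follows. The function `get_batch_sketch` is hypothetical, size-limit checks are omitted for brevity, and it assumes func returns one item per key, in key order.

```python
from typing import Any, Callable, Dict, Hashable, List, Optional, Sequence


def get_batch_sketch(
    data_cache: Dict[Hashable, Any],
    keys: Sequence[Hashable],
    func: Optional[Callable[..., Sequence[Any]]],
    *func_args,
    **func_kwargs,
) -> List[Any]:
    """Hypothetical batched lookup: a hit only if every key is cached."""
    if all(k in data_cache for k in keys):
        return [data_cache[k] for k in keys]  # full hit: func is never called
    # At least one miss: recompute the whole batch with a single call.
    batch = func(*func_args, **func_kwargs)
    if len(batch) != len(keys):
        raise ValueError("func must return one item per key")
    for k, item in zip(keys, batch):
        data_cache[k] = item  # write each item under its own key
    return list(batch)
```

Recomputing the whole batch on a partial miss trades some redundant work for a single batched call, which is the case the method targets.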
- abstract is_cached(key: Hashable)
Whether a value is already stored under key.
- Parameters:
key – key to query
- Returns:
bool: whether a value is already stored under key.
- abstract load(path: str)
Load cache from the local file system.
- Parameters:
path – file path to the local cache file to load.
- abstract query(key: Hashable)
Define how the data stored under key is read from the cache. Used by get and get_batch.
- Parameters:
key – key to read
- Returns:
The data stored under key.
- static size_in_bound(current_size, in_data_size, limit_size)
Check whether the cache size stays in bound after new data is added to the cache.
- Parameters:
current_size – current size of the cache
in_data_size – size of the new data
limit_size – current size limit (the resulting size must be no greater than this). If negative, no size limit is enforced.
- Returns:
bool: whether the size is still in bound after the new data is added to the cache.
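A minimal sketch of the documented semantics, assuming the new size is simply current_size + in_data_size and that "in bound" means not exceeding limit_size; the library's own code may differ in detail.

```python
def size_in_bound(current_size: int, in_data_size: int, limit_size: int) -> bool:
    """Sketch: True if adding in_data_size keeps the cache within limit_size."""
    if limit_size < 0:
        return True  # negative limit: no size limit is enforced
    return current_size + in_data_size <= limit_size  # "no greater than"
```

For example, with limit_size=4 a cache holding 3 entries can accept one more, but a cache already holding 4 cannot.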
- property size_limit
- write_batch(keys: List[Hashable], batch: V)
Write a batch of data to the cache.
- Parameters:
keys – list of keys corresponding to individual data points in the batch.
batch – batch data to cache.
- write_to_cache(key: Hashable, value: V)
Write the data (value) to the given address (key) in the cache.
- Parameters:
key – any hashable that identifies the address of the data in the cache
value – value of the data to cache
torch_staintools.cache.tensor_cache module
- class torch_staintools.cache.tensor_cache.TensorCache(size_limit: int, device: device | None = None)
Bases: Cache[Dict[Hashable, Tensor], Tensor]
An implementation of Cache specifically for tensors, backed by a built-in dict.
For now, it is used to store stain matrices directly in CPU or GPU memory, since stain matrices are typically small (e.g., 2x3 for mapping between H&E and RGB).
The size of the concentrations, however, is proportional to the number of pixels x num_stains, so they may be better cached on the local file system.
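To make the size argument concrete, the following back-of-the-envelope calculation compares one stain matrix with one concentration map. The 1024 x 1024 tile, two stains, and float32 storage are illustrative assumptions, not values taken from the library.

```python
# Rough memory estimate: one stain matrix vs. one concentration map.
# Assumptions: float32 (4 bytes) storage, a 1024 x 1024 tile, 2 stains (H&E).
BYTES_PER_FLOAT32 = 4
num_stains, num_channels = 2, 3
height, width = 1024, 1024

stain_matrix_bytes = num_stains * num_channels * BYTES_PER_FLOAT32
concentration_bytes = height * width * num_stains * BYTES_PER_FLOAT32

print(stain_matrix_bytes)           # 24
print(concentration_bytes)          # 8388608
print(concentration_bytes / 2**20)  # 8.0 (MiB)
```

A stain matrix costs a few dozen bytes regardless of image size, while the concentration map grows with the pixel count, which is why only the former is a natural fit for in-memory caching.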
- __size_limit: int
- _dump_helper(path: str)
Dump the dict to the local file system.
Note: A copy of the dict will be created, with all stored tensors copied to CPU. Dumped tensors are all CPU tensors.
- Parameters:
path – file path to dump.
- _new_cache() → Dict
Implementation of creating new cache - built-in dict.
- Returns:
A new empty dict.
- static _to_device(data_cache: Dict[Hashable, Tensor], device: device, dict_inplace: bool = True)
Helper function to move all cached tensors to the specified device.
- Parameters:
data_cache – the dict to operate on.
device – target device. Note: if a tensor is already on the target device, tensor.to(device) is a no-op.
dict_inplace – whether to move the tensors in place within the same dict, or to create a new dict to store the moved tensors.
- Returns:
the original dict (dict_inplace=True) or the new dict (dict_inplace=False) storing the moved tensors.
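The dict_inplace switch can be illustrated without a GPU. In this sketch, the hypothetical `to_device_sketch` takes a `move` callable that stands in for `lambda t: t.to(device)`, so only the dict handling is shown; it is not the library's implementation.

```python
from typing import Any, Callable, Dict, Hashable


def to_device_sketch(
    data_cache: Dict[Hashable, Any],
    move: Callable[[Any], Any],
    dict_inplace: bool = True,
) -> Dict[Hashable, Any]:
    """Hypothetical stand-in for _to_device; `move` plays the role of t.to(device)."""
    target = data_cache if dict_inplace else {}  # reuse the dict or start a new one
    for key, value in list(data_cache.items()):  # snapshot items before writing
        target[key] = move(value)
    return target
```

With dict_inplace=False the original dict keeps its old values, which is useful when a copy is needed for a side task (for example, a CPU copy for dumping) without disturbing the live cache.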
- _write_to_cache_helper(key, value: Tensor)
Write value into the cache under key. The value is moved to the specified device (CPU/GPU) in the process.
- Parameters:
key – key to write
value – value to write
- classmethod build(*, size_limit: int = -1, device: device | None = None, path: str | None = None)
Factory builder.
- Parameters:
size_limit – maximum number of entries (keys) the cache may hold. A negative value means no limit is enforced.
device – which device (CPU or GPU) to store the tensors on. If None, it defaults to torch.device('cpu').
path – if specified, a previously dumped cache file is loaded from this path.
- Returns:
A new TensorCache instance.
- data_cache: Dict[Hashable, Tensor]
- device: device
- is_cached(key)
Whether a value is already stored under key.
- Parameters:
key – key to query
- Returns:
bool: whether a value is already stored under key.
- load(path: str)
Load cache from the local file system.
The in-memory cache is updated with the loaded entries: data already in memory is overwritten if the same key exists in the dumped cache file, and entries whose keys do not appear in the dumped file are left unchanged.
- Parameters:
path – file path to the local cache file to load.
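The merge rule described for load matches the semantics of dict.update. A minimal sketch on plain dicts, assuming the dumped file has already been deserialized into a dict; how the library reads the file from disk is not shown, and `merge_loaded` is a hypothetical name.

```python
from typing import Any, Dict, Hashable


def merge_loaded(in_memory: Dict[Hashable, Any], loaded: Dict[Hashable, Any]) -> Dict[Hashable, Any]:
    """Sketch of the documented load() merge rule, applied in place."""
    # Same key: the loaded value wins. Keys absent from `loaded` are untouched.
    in_memory.update(loaded)
    return in_memory
```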
- query(key) → Tensor
Implementation of the abstract method query: reads from the dict directly.
- Parameters:
key – key to read
- Returns:
the tensor stored under key
- Raises:
KeyError – if key is not in the cache.
- to(device: device)
Move the cache to the specified device, mimicking torch.nn.Module.to and torch.Tensor.to.
The dict itself is reused, but the tensors stored in it are copied to the target device if they are not already there.
- Parameters:
device – Target device
- Returns:
self.
- static validate_value_type(value: Tensor | ndarray)
Helper function to validate the input.
The value must be a torch.Tensor; if it is a NumPy ndarray, it is converted to a tensor.
- Parameters:
value – value to validate
- Returns:
torch.Tensor.
- Raises:
AssertionError – if the output is not a torch.Tensor.