torch_staintools.cache package

Submodules

torch_staintools.cache.base module

class torch_staintools.cache.base.Cache(size_limit: int)

Bases: ABC, Generic[C, V]

A simple abstraction of a cache.

__size_limit: int
abstract _dump_helper(path: str)

To implement: dump the cached data to the local file system.

Parameters:

path – output filename

Returns:

abstract _new_cache()
abstract _write_to_cache_helper(key: Hashable, value: V)

Write the data (value) to the given address (key) in the cache.

Parameters:
  • key – any hashable used as the address of the data in the cache

  • value – value of the data to cache

Returns:

abstract classmethod build(*args, **kwargs)
data_cache: C
dump(path: str, force_overwrite: bool = False)

Dump the cached data to the local file system.

Parameters:
  • path – output filename

  • force_overwrite – whether to overwrite an existing file at path

Returns:
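
A minimal sketch of dumping a cache to disk, using the concrete TensorCache subclass documented below; the output file name is hypothetical:

    import torch
    from torch_staintools.cache.tensor_cache import TensorCache

    cache = TensorCache.build(size_limit=-1)            # unbounded in-memory cache
    cache.write_to_cache('tile_001', torch.rand(2, 3))  # e.g., a small stain matrix

    cache.dump('stain_cache.pt')                        # hypothetical output path
    # on a later call, overwrite the file that now exists at that path
    cache.dump('stain_cache.pt', force_overwrite=True)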

get(key: Hashable, func: Callable | None, *func_args, **func_kwargs)

Get the data cached under key.

If the data for key is not yet cached, it will be computed by func(*func_args, **func_kwargs), and the result will be cached if the remaining capacity is sufficient.

Parameters:
  • key – the address of the data in cache

  • func – callable to evaluate the new data to cache if not yet cached under key

  • *func_args – positional arguments of func

  • **func_kwargs – keyword arguments of func

Returns:

the data stored under key, either read from the cache or newly computed by func.
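
A minimal sketch of get on the concrete TensorCache documented below; the key scheme and the estimator are hypothetical stand-ins for an expensive computation:

    import torch
    from torch_staintools.cache.tensor_cache import TensorCache

    def estimate_stain_matrix(image: torch.Tensor) -> torch.Tensor:
        # hypothetical placeholder for an expensive stain-matrix estimation
        return torch.rand(2, 3)

    cache = TensorCache.build(size_limit=-1)
    image = torch.rand(3, 256, 256)

    # miss: estimate_stain_matrix(image) is evaluated and the result is cached
    stain_matrix = cache.get('slide_042', estimate_stain_matrix, image)
    # hit: the cached tensor is returned and func is not evaluated again
    stain_matrix = cache.get('slide_042', estimate_stain_matrix, image)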

get_batch(keys: List[Hashable], func: Callable | None, *func_args, **func_kwargs) List[V]

Batchified get

The method assumes that func generates a whole batch of data on each call. This can be useful when batched processing is much faster than processing all inputs individually (e.g., CUDA tensors processed by an nn.Module).

It is a hit only if all keys are cached.

Parameters:
  • keys – list of keys corresponding to the batch input.

  • func – function to generate the data if the corresponding entry is not cached.

  • *func_args – positional args for the func.

  • **func_kwargs – keyword args for the func.

Returns:

List of queried results.
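
A minimal sketch of get_batch, assuming a hypothetical function that computes stain matrices for a whole batch in one call:

    import torch
    from torch_staintools.cache.tensor_cache import TensorCache

    def batch_stain_matrices(images: torch.Tensor) -> torch.Tensor:
        # hypothetical batched computation: one 2x3 matrix per image in the batch
        return torch.rand(images.shape[0], 2, 3)

    cache = TensorCache.build(size_limit=-1)
    images = torch.rand(4, 3, 256, 256)
    keys = [f'tile_{i}' for i in range(4)]

    # a hit only if all four keys are cached; otherwise func recomputes the whole batch
    matrices = cache.get_batch(keys, batch_stain_matrices, images)  # list of per-key tensors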

abstract is_cached(key: Hashable)

Whether a value is already stored under the key.

Parameters:

key – key to query

Returns:

bool indicating whether a value is already stored under the key.

abstract load(path: str)

Load cache from the local file system.

Parameters:

path – file path to load the cache from.

Returns:

abstract query(key: Hashable)

Defines how to read the data stored under key in the cache. Used in get and get_batch.

Parameters:

key – key to query.

Returns:

static size_in_bound(current_size, in_data_size, limit_size)

Check whether the cache size stays within the limit after new data is added.

Parameters:
  • current_size – current size of the cache

  • in_data_size – size of the new data

  • limit_size – size limit (the cache size must be no greater than this). If negative, no size limit is enforced.

Returns:

bool. Whether the size remains within the limit after the new data is added to the cache.
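
A small illustration of the bound check (a static method, callable on the class directly):

    from torch_staintools.cache.base import Cache

    Cache.size_in_bound(current_size=5, in_data_size=2, limit_size=10)  # True: 7 entries fit under a limit of 10
    Cache.size_in_bound(current_size=9, in_data_size=2, limit_size=10)  # False: 11 entries exceed a limit of 10
    Cache.size_in_bound(current_size=9, in_data_size=2, limit_size=-1)  # True: a negative limit disables the check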

property size_limit
write_batch(keys: List[Hashable], batch: V)

Write a batch of data to the cache.

Parameters:
  • keys – list of keys corresponding to individual data points in the batch.

  • batch – batch data to cache.

Returns:

write_to_cache(key: Hashable, value: V)

Write the data (value) to the given address (key) in the cache.

Parameters:
  • key – any hashable used as the address of the data in the cache

  • value – value of the data to cache

Returns:
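
A small sketch combining write_to_cache, is_cached, and query, again using the TensorCache subclass documented below:

    import torch
    from torch_staintools.cache.tensor_cache import TensorCache

    cache = TensorCache.build(size_limit=-1)
    cache.write_to_cache('slide_007', torch.rand(2, 3))

    if cache.is_cached('slide_007'):
        stain_matrix = cache.query('slide_007')  # read directly from the underlying dict

    # cache.query('missing_key')                 # would raise KeyError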

torch_staintools.cache.tensor_cache module

class torch_staintools.cache.tensor_cache.TensorCache(size_limit: int, device: device | None = None)

Bases: Cache[Dict[Hashable, Tensor], Tensor]

An implementation of Cache specifically for tensors, using a built-in dict.

For now, it is used to store stain matrices directly in CPU or GPU memory, since stain matrices are typically small (e.g., 2x3 for mapping between H&E and RGB).

The size of concentration maps, however, is proportional to the number of pixels x num_stains, so they may be better cached on the local file system.
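
To make the size argument concrete, a rough back-of-the-envelope comparison assuming float32 storage (purely illustrative numbers):

    # per image, float32 = 4 bytes per element
    stain_matrix_bytes = 2 * 3 * 4             # 24 bytes: negligible, fine to keep in CPU/GPU memory
    concentration_bytes = 1000 * 1000 * 2 * 4  # ~8 MB for a 1000x1000 tile with 2 stains
    print(stain_matrix_bytes, concentration_bytes)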

__size_limit: int
_dump_helper(path: str)

Dump the dict to the local file system.

Note: A copy of the dict will be created, with all stored tensors copied to CPU. Dumped tensors are all CPU tensors.

Parameters:

path – file path to dump.

Returns:

_new_cache() Dict

Implementation of creating a new cache: a built-in dict.

Returns:

A new empty dict.

static _to_device(data_cache: Dict[Hashable, Tensor], device: device, dict_inplace: bool = True)

Helper function to move all cached tensors to the specified device

Parameters:
  • data_cache – the dict to operate on.

  • device – target device. Note if a tensor is already on the target device, tensor.to(device) will be a no-op.

  • dict_inplace – whether to move the tensors in place within the same dict, or create a new dict to store the moved tensors.

Returns:

the original dict (dict_inplace=True) or a new dict (dict_inplace=False) storing the moved tensors.

_write_to_cache_helper(key, value: Tensor)

Write the value under the key in the cache. The value will be moved to the specified device (GPU/CPU) during the procedure.

Parameters:
  • key – key to write

  • value – value to write

Returns:

classmethod build(*, size_limit: int = -1, device: device | None = None, path: str | None = None)

Factory builder.

Parameters:
  • size_limit – limit of the cache size in number of entries (the number of keys may be no greater than this). A negative value means no size limit will be enforced.

  • device – which device (CPU or GPU) to store the tensors on. If None, it defaults to torch.device('cpu').

  • path – if specified, a previously dumped cache file will be loaded from this path.

Returns:

the built TensorCache.
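
A minimal sketch of the factory; the dump file path is hypothetical and assumed to exist:

    import torch
    from torch_staintools.cache.tensor_cache import TensorCache

    # in-memory cache on the CPU (the default when device is None), no size limit
    cpu_cache = TensorCache.build()

    # cache kept on the first GPU, limited to 256 entries, pre-populated from a dumped file
    gpu_cache = TensorCache.build(size_limit=256,
                                  device=torch.device('cuda:0'),
                                  path='stain_cache.pt')  # hypothetical existing dump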

data_cache: Dict[Hashable, Tensor]
device: device
is_cached(key)

Whether a value is already stored under the key.

Parameters:

key – key to query

Returns:

bool indicating whether a value is already stored under the key.

load(path: str)

Load cache from the local file system.

Keys will be updated. Cached data already in memory will be overwritten if the same key exists in the dumped cache file being loaded. Cached data whose keys are not present in the dumped cache file will not be affected.

Parameters:

path – file path to the local cache file to load.

Returns:
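
A sketch of the dump/load round trip and the merge behavior described above (the file name is hypothetical):

    import torch
    from torch_staintools.cache.tensor_cache import TensorCache

    cache = TensorCache.build()
    cache.write_to_cache('a', torch.rand(2, 3))
    cache.write_to_cache('b', torch.rand(2, 3))
    cache.dump('stain_cache.pt', force_overwrite=True)

    other = TensorCache.build()
    other.write_to_cache('b', torch.rand(2, 3))  # overwritten by the loaded entry for 'b'
    other.write_to_cache('c', torch.rand(2, 3))  # unaffected: 'c' is not in the dumped file
    other.load('stain_cache.pt')                 # 'other' now holds 'a' and 'b' from the file, plus 'c'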

query(key) Tensor

Implementation of the abstract method query.

Reads from the dict directly.

Parameters:

key – key to query.

Returns:

queried output

Raises:

KeyError – if the key is not present in the cache.

to(device: device)

Move the cache to the specified device, mimicking torch.nn.Module.to and torch.Tensor.to.

The dict itself is reused, but the tensors stored in it may be copied to the target device if they are not already there.

Parameters:

device – Target device

Returns:

self.
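
A small sketch of moving an existing cache between devices, mirroring Module.to / Tensor.to; the CUDA branch assumes a GPU is available:

    import torch
    from torch_staintools.cache.tensor_cache import TensorCache

    cache = TensorCache.build()                   # stored on the CPU by default
    cache.write_to_cache('x', torch.rand(2, 3))

    if torch.cuda.is_available():
        cache = cache.to(torch.device('cuda:0'))  # tensors are copied to the GPU; returns self
    cache = cache.to(torch.device('cpu'))         # no-op for tensors already on the CPU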

static validate_value_type(value: Tensor | ndarray)

Helper function to validate the input value.

The value must be a torch.Tensor; if it is a numpy ndarray, it will be converted to a tensor.

Parameters:

value – value to validate

Returns:

torch.Tensor.

Raises:

AssertionError – if the resulting value is not a torch.Tensor.
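
A small sketch of the validator, which accepts either a tensor or a numpy array:

    import numpy as np
    import torch
    from torch_staintools.cache.tensor_cache import TensorCache

    t1 = TensorCache.validate_value_type(torch.rand(2, 3))      # returned as a tensor unchanged
    t2 = TensorCache.validate_value_type(np.random.rand(2, 3))  # ndarray converted to a torch.Tensor
    assert isinstance(t1, torch.Tensor) and isinstance(t2, torch.Tensor)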

Module contents