from __future__ import annotations
import atexit
import functools
import hashlib
import logging
import numbers
import time
from typing import Any, Callable
import traceback
import numpy as np
from .store_keeper import StoreKeeper
[docs]
class umbrella_cache:
    # The StoreKeeper instance returned by `StoreKeeper.get_instance()` and
    # its stores; looked up once when the class is defined and shared by
    # every decorated function.
    _store_keeper = StoreKeeper.get_instance()
    _memory_store = _store_keeper.memory_store
    _disk_store = _store_keeper.disk_store
    # Set to True by the first cached call once the keeper thread has been
    # started (or was found to be alive already).
    _store_keeper_started = False
    # Call `StoreKeeper.close` when the interpreter exits.
    atexit.register(_store_keeper.close)

    def __init__(self,
                 topic: str = "general",
                 bypass_memory_store: bool = False,
                 custom_handlers: dict[type[Any], Callable] | None = None,
                 evaluation_time_threshold: float = 0.01
                 ):
        """A hybrid disk and in-memory cache decorator compatible with numpy

        This hybrid cache hashes the input parameters to a function
        to compute the cache key. If a persistent disk store is configured,
        data are saved to disk in a background thread
        (see :class:`.StoreKeeper`).

        The `umbrella_cache` decorator can be used safely with regular
        functions and methods defined in classes.

        Parameters
        ----------
        topic: str
            A general topic under which the cached values are stored.
            It must be usable as a directory name: only lowercase ASCII
            letters, digits and dashes are kept, every other character
            (including uppercase letters) is removed. If nothing remains,
            ``"general"`` is used.
        bypass_memory_store: bool
            Set this to true to disable storing cached values in memory.
            If no disk store (see :class:`.StoreKeeper`) is defined,
            then data are never cached.
        custom_handlers:
            Custom handlers for hashing objects not handled by
            :func:`update_hash`. ``custom_handlers`` is a dictionary
            where a key is a class (or optionally the name
            of a class as a string) and a value is a callable that
            translates a class instance to something that can be
            processed by :func:`update_hash`.
        evaluation_time_threshold: float
            Minimum time in seconds that evaluating the wrapped function
            has to take for its return value to be cached. If the
            computation is faster than this threshold, the value is not
            cached. Only the function call itself is timed, not the
            hashing of its arguments or the cache lookups.

        Notes
        -----
        The behavior of `umbrella_cache` is defined by the global
        :class:`.StoreKeeper` instance. The ``StoreKeeper`` is
        a background thread that manages data in memory and on disk.

        If you are using other decorators with this decorator, please
        make sure to apply the `umbrella_cache` first (first line before
        method definition). This wrapper uses name, doc, and filename of the
        method to identify it. If another wrapper does not implement
        a unique `__doc__` and is applied to multiple methods, then
        `umbrella_cache` might return values of the wrong method.
        """
        # topic must be a valid directory name
        topic = "".join(
            [t for t in topic if t in "abcdefghijklmnopqrstuvwxyz-0123456789"])
        self.topic = topic or "general"
        self.use_memory_store = not bypass_memory_store
        self.custom_handlers = custom_handlers
        self.evaluation_time_threshold = evaluation_time_threshold
        self.logger = logging.getLogger(__name__)

    def __call__(self, func):
        """Return a wrapper around ``func`` that caches its return values."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not self.use_memory_store and not self._disk_store:
                # shortcut when nothing is cached
                return func(*args, **kwargs)

            # Make sure the StoreKeeper thread is running
            if not umbrella_cache._store_keeper_started:
                if not umbrella_cache._store_keeper.is_alive():
                    umbrella_cache._store_keeper.start()
                umbrella_cache._store_keeper_started = True

            ref = compute_hash_for_cache(func, args, kwargs,
                                         custom_handlers=self.custom_handlers)
            key = f"{self.topic}/{ref[:3]}/{ref[3:6]}/{ref[6:]}"

            if self.use_memory_store and key in self._memory_store:
                return self._memory_store[key]

            # The disk store may be absent (falsy) when only the memory
            # store is in use; a membership test on it would then fail.
            if self._disk_store and key in self._disk_store:
                try:
                    value = self._disk_store[key]
                except Exception:
                    # Maybe the data got deleted. Log it, drop the key from
                    # the disk store index and fall through, so the value is
                    # computed and returned as if it had never been cached.
                    self.logger.warning(
                        "Could not fetch %s from disk store: %s",
                        key, traceback.format_exc())
                    try:
                        self._disk_store.index.remove(key)
                    except Exception:
                        # Best effort: the index may not hold the key anymore.
                        self.logger.debug(
                            "Could not remove %s from disk store index",
                            key, exc_info=True)
                else:
                    if self.use_memory_store:
                        self._memory_store[key] = value
                    return value

            # Time only the evaluation itself (see `evaluation_time_threshold`).
            t0 = time.perf_counter()
            value = func(*args, **kwargs)
            elapsed = time.perf_counter() - t0

            if elapsed < self.evaluation_time_threshold:
                pass  # too cheap to be worth caching
            elif self.use_memory_store:
                self._memory_store[key] = value
            elif self._disk_store:
                # Only write to DiskStore directly, if the memory store
                # is disabled. Normally, the StoreKeeper does this in the
                # background.
                self._disk_store[key] = value
            return value
        return wrapper
[docs]
def compute_hash_for_cache(
        func: Callable,
        args: list | tuple,
        kwargs: dict,
        custom_handlers: dict[type[Any], Callable] | None = None,
) -> str:
    """Compute the hash for caching the function return value

    The digest covers the positional and keyword arguments (hashed with
    :func:`update_hash`) as well as the function's ``__name__``, its
    ``__doc__`` and the filename of its code object, so that different
    functions called with the same arguments get different keys.

    Parameters
    ----------
    func:
        The function whose return value is cached. It must provide
        ``__name__``, ``__doc__`` and ``__code__``.
    args:
        Positional arguments passed to ``func``.
    kwargs:
        Keyword arguments passed to ``func``.
    custom_handlers:
        Passed on to :func:`update_hash` when hashing ``args`` and
        ``kwargs``.

    Returns
    -------
    str
        The hexadecimal MD5 digest.
    """
    try:
        # MD5 serves as a fast fingerprint here, not for security.
        # Declaring that keeps it usable on FIPS-restricted OpenSSL builds.
        the_hash = hashlib.md5(usedforsecurity=False)
    except TypeError:
        # `usedforsecurity` is only accepted from Python 3.9 on.
        the_hash = hashlib.md5()
    # hash arguments
    update_hash(the_hash, args, custom_handlers=custom_handlers)
    # hash keyword arguments
    update_hash(the_hash, kwargs, custom_handlers=custom_handlers)
    # metadata identifying the function
    update_hash(the_hash, func.__name__)
    update_hash(the_hash, func.__doc__)
    update_hash(the_hash, func.__code__.co_filename)
    return the_hash.hexdigest()
[docs]
def update_hash(the_hash,
                arg,
                custom_handlers: dict[type[Any], Callable] | None = None
                ) -> None:
    """Update a hashing object with a Python object

    Supported are numbers, strings, ``None``, bytes, numpy arrays,
    lists, tuples and dicts (recursively), instances covered by
    ``custom_handlers`` and objects that implement ``for_json()``.

    Every value is written together with a type tag and its length, so
    structurally different inputs do not collide: ``(1, 23)`` and
    ``(12, 3)`` differ, ``1`` and ``"1"`` differ, and numpy arrays with
    identical raw bytes but different dtype or shape differ.

    Parameters
    ----------
    the_hash:
        A :mod:`hashlib` hash object (anything with ``update(bytes)``).
    arg:
        The object to hash.
    custom_handlers:
        Optional mapping from a class (or the name of a class as a
        string) to a callable that converts an instance into something
        this function can hash. Class keys match via :func:`isinstance`;
        string keys match the name of the object's class or of any of its
        base classes. The first matching entry (in dict order) is used,
        and its result is hashed with the same ``custom_handlers``.

    Raises
    ------
    ValueError
        If there is no rule to hash ``arg`` or one of its elements.
    """
    if isinstance(arg, numbers.Number):
        _update_hash_chunk(the_hash, b"n", f"{arg}".encode("utf-8"))
    elif isinstance(arg, str):
        _update_hash_chunk(the_hash, b"s", arg.encode("utf-8"))
    elif arg is None:
        _update_hash_chunk(the_hash, b"z")
    elif isinstance(arg, bytes):
        _update_hash_chunk(the_hash, b"b", arg)
    elif isinstance(arg, np.ndarray):
        # Identical raw bytes can encode different arrays, so dtype and
        # shape are part of the hash as well.
        meta = f"{arg.dtype.str}|{arg.dtype}|{arg.shape}".encode("utf-8")
        if arg.dtype.hasobject:
            # Object arrays store references, so `tobytes()` would encode
            # memory addresses; hash the contained Python objects instead.
            _update_hash_chunk(the_hash, b"o", meta)
            update_hash(the_hash, arg.ravel().tolist(),
                        custom_handlers=custom_handlers)
        else:
            _update_hash_chunk(the_hash, b"a", meta)
            _update_hash_chunk(the_hash, b"b", arg.tobytes())
    elif isinstance(arg, (list, tuple)):
        # Tag container type and length so element boundaries are unambiguous.
        tag = b"l" if isinstance(arg, list) else b"t"
        _update_hash_chunk(the_hash, tag, b"%d" % len(arg))
        for item in arg:
            update_hash(the_hash, item, custom_handlers=custom_handlers)
    elif isinstance(arg, dict):
        # Sorted by key so that insertion order does not affect the hash.
        _update_hash_chunk(the_hash, b"d", b"%d" % len(arg))
        for key, value in sorted(arg.items()):
            update_hash(the_hash, key, custom_handlers=custom_handlers)
            update_hash(the_hash, value, custom_handlers=custom_handlers)
    else:
        handler = _find_custom_handler(arg, custom_handlers)
        if handler is not None:
            update_hash(the_hash, handler(arg),
                        custom_handlers=custom_handlers)
            return  # no further checks necessary
        # Final option are objects that can serialize themselves via
        # `for_json`.
        if hasattr(arg, "for_json"):
            try:
                update_hash(the_hash, arg.for_json(),
                            custom_handlers=custom_handlers)
            except Exception as exc:
                raise ValueError(
                    f"No rule to hash object of type {type(arg)}") from exc
            return  # no further checks necessary
        raise ValueError(f"No rule to hash object of type {type(arg)}")


def _update_hash_chunk(the_hash, tag: bytes, payload: bytes = b"") -> None:
    """Feed one self-delimiting ``tag + length + b":" + payload`` chunk."""
    the_hash.update(tag + b"%d:" % len(payload) + payload)


def _find_custom_handler(
        arg,
        custom_handlers: dict[type[Any], Callable] | None,
) -> Callable | None:
    """Return the first handler in ``custom_handlers`` matching ``arg``."""
    if not custom_handlers:
        return None
    # Names of the class and all of its (direct and indirect) base classes.
    class_names = {cls.__name__ for cls in arg.__class__.__mro__}
    for cls, handler in custom_handlers.items():
        if isinstance(cls, str):
            if cls in class_names:
                return handler
        elif isinstance(arg, cls):
            return handler
    return None