# Source code for deepretro.utils.cache

"""
Explicit in-memory cache primitives for expensive library operations.

This module provides a small process-local ``CacheManager`` backed by in-memory
Python dictionaries. Each cache entry stores a value, an optional expiry
deadline measured with ``time.monotonic()``, and an optional tag. A secondary
``tag -> set[key]`` index makes group eviction efficient.

The cache is explicit rather than global: callers instantiate ``CacheManager``
objects and pass them where needed. The implementation does not use locks, so a
single ``CacheManager`` instance is not thread-safe.
"""

from __future__ import annotations

import hashlib
import json
import pickle
import sys
import time
from dataclasses import dataclass
from typing import Any

import structlog

# Module-level structured logger; CacheManager binds component="cache" onto it.
logger = structlog.get_logger(__name__)

# Private sentinel used as the default for CacheManager.get so a cached value
# of None can be distinguished from "key not present".
_MISS = object()

# Explicit public API of this module.
__all__ = [
    "CacheEntry",
    "CacheManager",
    "CacheStats",
    "make_args_hash",
    "make_cache_key",
]


@dataclass
class CacheStats:
    """
    Point-in-time cache statistics produced by :meth:`CacheManager.stats`.

    The numbers describe the cache *after* expired entries have been purged,
    so they reflect live data only. They are meant for diagnostics and
    monitoring, not for exact process-memory accounting.

    Attributes
    ----------
    hits : int
        How many ``CacheManager.get`` calls found a live value.
    misses : int
        How many ``CacheManager.get`` calls failed, counting expired keys.
    size_bytes : int
        Shallow estimate of the cache footprint: the top-level entry and tag
        dictionaries, their keys, and the immediate cached values. Objects
        referenced by those values are not traversed.
    num_entries : int
        Count of entries still alive after purging — i.e. the keys that
        still take part in lookups and tag eviction.
    """

    hits: int
    misses: int
    size_bytes: int
    num_entries: int
@dataclass
class CacheEntry:
    """
    One stored cache record.

    An entry bundles the cached payload with an optional expiry deadline
    (measured on the ``time.monotonic()`` clock) and an optional tag. Tags
    group several keys under one logical label — e.g. a molecule, a model
    configuration, or a request family — so they can be invalidated as a
    unit.

    Attributes
    ----------
    value : Any
        Payload handed back by ``CacheManager.get``.
    expires_at : float | None
        Monotonic-clock deadline after which the key is stale; ``None``
        disables automatic expiry.
    tag : str | None
        Group label supplied via ``cache.set(..., tag=...)``. Every key
        carrying the same tag is removed together by
        ``CacheManager.evict_tag``.
    """

    value: Any
    expires_at: float | None
    tag: str | None
def make_args_hash(*args: Any, **kwargs: Any) -> str:
    """
    Produce a deterministic SHA-256 hex digest of the given arguments.

    Serialization first attempts JSON with sorted keys (``default=str`` for
    non-JSON types); payloads JSON cannot encode fall back to pickle.

    Examples
    --------
    >>> make_args_hash("CCO", az_model="USPTO")
    '6ad01e27a3a319962ad084787e060ab0fa0e661cc7d3e018e96747b06f7bacf7'
    """
    payload = {"args": args, "kwargs": kwargs}
    try:
        serialized = json.dumps(payload, sort_keys=True, default=str).encode("utf-8")
    except (TypeError, ValueError):
        # Complex objects that JSON rejects are hashed via their pickle bytes.
        serialized = pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL)
    return hashlib.sha256(serialized).hexdigest()
def make_cache_key(namespace: str, *args: Any, version: int = 1, **kwargs: Any) -> str:
    """
    Compose a deterministic, namespaced cache key.

    Parameters
    ----------
    namespace : str
        Stable operation name, such as ``"run_az"``.
    *args : Any
        Positional arguments that influence the cached result.
    version : int, optional
        Cache schema version; bump it to invalidate old entries when the
        cached behavior changes. Defaults to 1.
    **kwargs : Any
        Keyword arguments that influence the cached result.

    Returns
    -------
    str
        Key usable with ``CacheManager.get`` and ``set``.

    Raises
    ------
    ValueError
        If ``namespace`` is empty.

    Examples
    --------
    >>> make_cache_key("run_az", "CCO", az_model="USPTO", version=1)
    'v1:run_az:6ad01e27a3a319962ad084787e060ab0fa0e661cc7d3e018e96747b06f7bacf7'
    """
    if not namespace:
        raise ValueError("namespace must be a non-empty string")
    return f"v{version}:{namespace}:{make_args_hash(*args, **kwargs)}"
class CacheManager:
    """
    Process-local in-memory cache with TTL expiry and tag-based eviction.

    Two dictionaries back each instance: ``_entries`` maps cache keys to
    :class:`CacheEntry` records, and ``_tags`` maps each tag to the set of
    keys currently carrying it. Expired keys are dropped lazily by ``get``,
    ``evict_tag`` removes every key under a tag, and ``stats`` purges stale
    values first so its counts describe live entries only.

    Instances are process-local and perform no locking, so a single
    ``CacheManager`` is not thread-safe. Share one instance only when shared
    state is intentional.

    Examples
    --------
    >>> cache = CacheManager()
    >>> key = make_cache_key("call_llm", "CCO", model="gpt-5.4", version=1)
    >>> miss = object()
    >>> cache.get(key, default=miss) is miss
    True
    >>> cache.set(key, {"molecule": "CCO"}, expire=300, tag="molecule:CCO")
    >>> cache.get(key)
    {'molecule': 'CCO'}
    >>> cache.evict_tag("molecule:CCO")
    1
    """

    def __init__(self) -> None:
        """Create an empty cache with zeroed hit/miss counters."""
        self._entries: dict[str, CacheEntry] = {}
        self._tags: dict[str, set[str]] = {}
        self._hits = 0
        self._misses = 0
        self._log = logger.bind(component="cache")

    def purge_if_expired(self, key: str) -> bool:
        """
        Drop ``key`` if its expiry deadline is already behind us.

        Parameters
        ----------
        key : str
            Cache key to inspect.

        Returns
        -------
        bool
            ``True`` when an expired entry existed and was removed,
            ``False`` otherwise (absent key, no TTL, or still fresh).
        """
        entry = self._entries.get(key)
        if entry is None or entry.expires_at is None:
            return False
        if time.monotonic() < entry.expires_at:
            # Deadline not reached yet; entry stays.
            return False
        self.delete_key(key)
        self._log.debug("cache.expired", key=key)
        return True

    def delete_key(self, key: str) -> bool:
        """
        Remove ``key`` from the entry store and the tag index.

        Parameters
        ----------
        key : str
            Cache key to remove.

        Returns
        -------
        bool
            ``True`` when an entry was actually removed, else ``False``.
        """
        removed = self._entries.pop(key, None)
        if removed is None:
            return False
        if removed.tag is not None:
            members = self._tags.get(removed.tag)
            if members is not None:
                members.discard(key)
                if not members:
                    # Drop empty tag buckets so the index stays compact.
                    self._tags.pop(removed.tag, None)
        return True

    def purge_expired_entries(self) -> None:
        """
        Evict every entry whose TTL has already elapsed.

        Useful before measuring cache size or exporting diagnostics.
        """
        for cached_key in list(self._entries):
            self.purge_if_expired(cached_key)

    def estimate_size_bytes(self) -> int:
        """
        Shallowly approximate the cache's in-memory footprint.

        Counts the two top-level dictionaries, their keys, the tag sets,
        and the immediate cached values; referenced objects are not
        traversed recursively.
        """
        total = sys.getsizeof(self._entries) + sys.getsizeof(self._tags)
        for cached_key, entry in self._entries.items():
            total += sys.getsizeof(cached_key)
            total += sys.getsizeof(entry)
            total += sys.getsizeof(entry.value)
        for tag_name, members in self._tags.items():
            total += sys.getsizeof(tag_name)
            total += sys.getsizeof(members)
        return total

    def get(self, key: str, default: Any = _MISS) -> Any:
        """
        Look up a cached value, lazily purging it if expired.

        Parameters
        ----------
        key : str
            Cache key.
        default : Any, optional
            Returned on a miss; defaults to an internal sentinel object.

        Returns
        -------
        Any
            The cached value, or ``default`` when absent or expired.

        Examples
        --------
        >>> cache = CacheManager()
        >>> miss = object()
        >>> cache.get("missing", default=miss) is miss
        True
        """
        if not self.purge_if_expired(key):
            entry = self._entries.get(key)
            if entry is not None:
                self._hits += 1
                self._log.debug("cache.hit", key=key)
                return entry.value
        # Either the key expired just now or it was never cached.
        self._misses += 1
        self._log.debug("cache.miss", key=key)
        return default

    def set(
        self,
        key: str,
        value: Any,
        *,
        expire: float | None = None,
        tag: str | None = None,
    ) -> None:
        """
        Store a value, optionally with a TTL and a tag.

        Parameters
        ----------
        key : str
            Cache key.
        value : Any
            Value to cache.
        expire : float | None, optional
            Time-to-live in seconds; ``None`` disables expiry.
        tag : str | None, optional
            Group label for later ``evict_tag`` calls; several keys may
            share one tag.

        Examples
        --------
        >>> cache = CacheManager()
        >>> cache.set("demo", {"smiles": "CCO"}, expire=60, tag="molecule:CCO")
        """
        # Drop any previous entry first so a stale tag association is cleaned.
        self.delete_key(key)
        deadline = time.monotonic() + expire if expire is not None else None
        self._entries[key] = CacheEntry(value=value, expires_at=deadline, tag=tag)
        if tag is not None:
            self._tags.setdefault(tag, set()).add(key)
        self._log.debug("cache.set", key=key, expire=expire, tag=tag)

    def evict_tag(self, tag: str) -> int:
        """
        Delete every live entry carrying ``tag``.

        Parameters
        ----------
        tag : str
            Group label whose entries should be removed; one tag may cover
            several cache keys.

        Returns
        -------
        int
            How many entries were evicted.

        Examples
        --------
        >>> cache = CacheManager()
        >>> cache.set("a", 1, tag="batch:1")
        >>> cache.set("b", 2, tag="batch:1")
        >>> cache.evict_tag("batch:1")
        2
        """
        evicted = 0
        # Snapshot the key set: delete_key mutates the tag index as we go.
        for tagged_key in list(self._tags.get(tag, ())):
            if self.delete_key(tagged_key):
                evicted += 1
        self._log.info("cache.evict_tag", tag=tag, removed=evicted)
        return evicted

    def clear(self) -> None:
        """Discard every entry and tag held by this instance."""
        self._entries.clear()
        self._tags.clear()
        self._log.info("cache.clear")

    def stats(self) -> CacheStats:
        """
        Return a snapshot of cache statistics.

        Returns
        -------
        CacheStats
            Hit count, miss count, shallow byte estimate, and live entry
            count, computed after expired entries are purged.

        Examples
        --------
        >>> cache = CacheManager()
        >>> cache.set("demo", 1)
        >>> stats = cache.stats()
        >>> (stats.hits, stats.misses, stats.num_entries)
        (0, 0, 1)
        """
        self.purge_expired_entries()
        return CacheStats(
            hits=self._hits,
            misses=self._misses,
            size_bytes=self.estimate_size_bytes(),
            num_entries=len(self._entries),
        )

    def close(self) -> None:
        """Drop all contents; retained for compatibility with older callers."""
        self.clear()