Source code for boltz_data.ccd._compressed_dictionary

import gzip
import pickle
from collections.abc import Iterator, Mapping
from functools import lru_cache
from pathlib import Path
from typing import Any

from smart_open import open as smart_open  # type: ignore[import-untyped]

from ._compress import compress_chemical_component
from ._decompress import decompress_chemical_component
from ._models import ChemicalComponent


class CompressedChemicalComponentDictionary(Mapping[str, ChemicalComponent]):
    """
    A dictionary of chemical components stored in compressed binary format.

    This class provides a mapping interface to access chemical components while
    keeping them compressed in memory. Components are decompressed on-demand and
    cached using an LRU cache to balance memory usage and performance.

    Membership tests (``comp_id in dictionary``) and iteration only consult the
    compressed store; only item access decompresses a component.
    """

    # Maximum number of decompressed components kept in each instance's cache.
    _CACHE_SIZE = 256

    # Component ID -> compressed component bytes.
    ccd: dict[str, bytes]

    def __init__(self, ccd: dict[str, bytes], /) -> None:
        """
        Initialize the compressed dictionary.

        Args:
            ccd: Dictionary mapping component IDs to compressed bytes. The
                dictionary is stored by reference, not copied.
        """
        self.ccd = ccd
        self._get_cached = self._make_cache()

    def _make_cache(self) -> Any:
        """Return a fresh LRU-cached wrapper around ``_get_cached_impl`` for this instance."""
        # Wrapping the *bound* method gives every instance its own cache, so
        # entries are never shared between instances.
        return lru_cache(maxsize=self._CACHE_SIZE)(self._get_cached_impl)

    def _get_cached_impl(self, comp_id: str) -> ChemicalComponent:
        """Decompress the component stored under ``comp_id`` (KeyError if absent)."""
        return decompress_chemical_component(self.ccd[comp_id])

    def __getstate__(self) -> dict[str, Any]:
        """Return the picklable state: only the compressed data, never the cache."""
        return {"ccd": self.ccd}

    def __setstate__(self, newstate: dict[str, Any]) -> None:
        """Restore the pickled state and rebuild the (empty) per-instance cache."""
        self.__dict__.update(newstate)
        self._get_cached = self._make_cache()

    def __getitem__(self, comp_id: str) -> ChemicalComponent:
        """
        Return the decompressed component for ``comp_id``.

        Raises:
            KeyError: If ``comp_id`` is not in the dictionary.
        """
        return self._get_cached(comp_id)

    def __contains__(self, comp_id: object) -> bool:
        """Return whether ``comp_id`` is a known component ID, without decompressing it."""
        # Mapping's default __contains__ calls __getitem__, which would
        # decompress the component and evict another entry from the LRU cache
        # just to answer a membership question.
        return comp_id in self.ccd

    def __iter__(self) -> Iterator[str]:
        """Iterate over the component IDs."""
        return iter(self.ccd)

    def __len__(self) -> int:
        """Return the number of stored components."""
        return len(self.ccd)

    @classmethod
    def from_file(cls, path: str | Path, /) -> "CompressedChemicalComponentDictionary":
        """
        Load a compressed chemical component dictionary from a file.

        Supports both plain pickle (.pkl) and gzipped pickle (.pkl.gz) formats.

        Args:
            path: Path to the file containing the compressed dictionary.

        Returns:
            A CompressedChemicalComponentDictionary instance.
        """
        path = Path(path)
        data = path.read_bytes()

        # Treat the payload as gzip if it starts with the gzip magic number or
        # the file name ends in ".gz".
        if data[:2] == b"\x1f\x8b" or path.suffix == ".gz":
            data = gzip.decompress(data)

        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # Only load dictionaries that come from a trusted source.
        ccd = pickle.loads(data)  # noqa: S301
        return cls(ccd)

    def to_file(self, path: str | Path, /) -> None:
        """
        Save the compressed chemical component dictionary to a file.

        If the file extension is .gz, the output will be gzip compressed
        (compression is delegated to ``smart_open``, which selects it from the
        file extension).

        Args:
            path: Path to save the dictionary to.
        """
        path = Path(path)
        data = pickle.dumps(self.ccd)
        with smart_open(str(path), "wb") as f:
            f.write(data)

    @classmethod
    def from_ccd(cls, ccd: dict[str, ChemicalComponent], /) -> "CompressedChemicalComponentDictionary":
        """
        Create a compressed dictionary from a dictionary of chemical components.

        Args:
            ccd: Dictionary mapping component IDs to ChemicalComponent objects.

        Returns:
            A CompressedChemicalComponentDictionary instance.
        """
        compressed_ccd = {comp_id: compress_chemical_component(component) for comp_id, component in ccd.items()}
        return cls(compressed_ccd)