Source code for boltz_data.ccd._compressed_dictionary
import gzip
import pickle
from collections.abc import Iterator, Mapping
from functools import lru_cache
from pathlib import Path
from typing import Any
from smart_open import open as smart_open # type: ignore[import-untyped]
from ._compress import compress_chemical_component
from ._decompress import decompress_chemical_component
from ._models import ChemicalComponent
[docs]
class CompressedChemicalComponentDictionary(Mapping[str, ChemicalComponent]):
    """
    A dictionary of chemical components stored in compressed binary format.

    This class provides a read-only mapping interface to access chemical components
    while keeping them compressed in memory. Components are decompressed on demand
    and the most recently used ones are kept in a per-instance LRU cache to balance
    memory usage and performance.
    """

    ccd: dict[str, bytes]

    # Maximum number of decompressed components kept in each instance's LRU cache.
    _CACHE_SIZE = 256

    def __init__(self, ccd: dict[str, bytes], /) -> None:
        """
        Initialize the compressed dictionary.

        Args:
            ccd: Dictionary mapping component IDs to compressed bytes.
        """
        self.ccd = ccd
        self._reset_cache()

    def _reset_cache(self) -> None:
        """(Re)create this instance's LRU cache of decompressed components."""
        # The cache wraps this instance's bound method (instead of decorating the
        # method at class level), so every dictionary owns its own cache and a
        # class-wide cache cannot keep other instances alive.
        self._get_cached = lru_cache(maxsize=self._CACHE_SIZE)(self._get_cached_impl)

    def _get_cached_impl(self, comp_id: str) -> ChemicalComponent:
        # Raises KeyError for unknown IDs; lru_cache does not cache raised exceptions.
        return decompress_chemical_component(self.ccd[comp_id])

    def __getstate__(self) -> dict[str, Any]:
        # Only the compressed data is pickled; the cache wrapper is rebuilt in
        # __setstate__ rather than serialized.
        return {"ccd": self.ccd}

    def __setstate__(self, newstate: dict[str, Any]) -> None:
        self.__dict__.update(newstate)
        self._reset_cache()

    def __getitem__(self, comp_id: str) -> ChemicalComponent:
        return self._get_cached(comp_id)

    def __contains__(self, comp_id: object) -> bool:
        # Mapping's default __contains__ calls __getitem__, which would decompress
        # (and cache) the component just to answer a membership test.
        return comp_id in self.ccd

    def __iter__(self) -> Iterator[str]:
        return iter(self.ccd)

    def __len__(self) -> int:
        return len(self.ccd)

    def __repr__(self) -> str:
        # Deliberately avoids decompressing anything.
        return f"{type(self).__name__}(<{len(self.ccd)} compressed components>)"

    @classmethod
    def from_file(cls, path: str | Path, /) -> "CompressedChemicalComponentDictionary":
        """
        Load a compressed chemical component dictionary from a file.

        Supports both plain pickle (.pkl) and gzipped pickle (.pkl.gz) formats. The
        file is opened with ``smart_open``, the same I/O layer ``to_file`` writes with,
        so anything written by ``to_file`` can be read back. Gzip data is also
        detected by its magic number, so gzipped files without a ``.gz`` extension
        load too.

        Warning:
            The file contents are unpickled, which can execute arbitrary code.
            Only load files from trusted sources.

        Args:
            path: Path (or smart_open URI) of the file containing the compressed
                dictionary.

        Returns:
            A CompressedChemicalComponentDictionary instance.
        """
        # Use str() instead of Path(): Path would collapse the "//" in URIs such as
        # "s3://bucket/key".
        with smart_open(str(path), "rb") as f:
            data = f.read()
        # smart_open already decompresses files based on their extension (e.g. .gz);
        # this check additionally handles gzip data stored under another name.
        if data[:2] == b"\x1f\x8b":
            data = gzip.decompress(data)
        ccd = pickle.loads(data)  # noqa: S301 - trusted input only, see docstring
        return cls(ccd)

    def to_file(self, path: str | Path, /) -> None:
        """
        Save the compressed chemical component dictionary to a file.

        The file is written with ``smart_open``, which compresses the output based on
        the file extension; e.g. a ``.gz`` extension produces gzip-compressed output.

        Args:
            path: Path (or smart_open URI) to save the dictionary to.
        """
        # Pass the string form directly: round-tripping through Path would turn
        # "s3://bucket/key" into "s3:/bucket/key".
        with smart_open(str(path), "wb") as f:
            pickle.dump(self.ccd, f)

    @classmethod
    def from_ccd(cls, ccd: dict[str, ChemicalComponent], /) -> "CompressedChemicalComponentDictionary":
        """
        Create a compressed dictionary from a dictionary of chemical components.

        Args:
            ccd: Dictionary mapping component IDs to ChemicalComponent objects.

        Returns:
            A CompressedChemicalComponentDictionary instance.
        """
        compressed_ccd = {
            comp_id: compress_chemical_component(component)
            for comp_id, component in ccd.items()
        }
        return cls(compressed_ccd)