# Source code for boltz_data.mol._mol._bzmol

"""Base BZMol class for molecular structures."""

from functools import cached_property
from typing import Annotated, Any

import numpy as np
import numpy.typing as npt
from pydantic import BaseModel, Field
from rdkit import Chem

from boltz_data.pydantic import NDArray, Shape, ValidationContext


def _default_atom_resolved(arg: dict[str, Any]) -> NDArray[np.bool]:
    """Default factory for ``BZMol.atom_resolved``.

    Pydantic passes the dict of field data validated so far. The result is a
    0-d boolean array: ``True`` when ``atom_coordinates`` is present in that
    data, ``False`` otherwise.
    """
    has_coordinates = "atom_coordinates" in arg
    return np.array(has_coordinates)


[docs] class BZMol(BaseModel): """ Base representation of a molecular structure with atoms and bonds. This object can represent both single molecules and multiple disconnected molecules, in a similar manner to an RDKit `Mol` object. Unlike a `Mol` object, information is stored by property, rather than by atom. For example, the `atom_name` property is an array of shape `(n_atoms,)` containing the names of all atoms. """
[docs] def __init__(self, /, **data: Any) -> None: self.__pydantic_validator__.validate_python(data, self_instance=self, context=ValidationContext())
atom_element: Annotated[NDArray[np.uint8], Shape("atom")] """Atomic numbers for each atom, stored as int array.""" atom_name: Annotated[NDArray[np.str_], Shape("atom")] = Field(default=np.array("")) """The atom names (e.g., 'CA', 'N', 'C', 'O' for proteins).""" atom_charge: Annotated[NDArray[np.int8], Shape("atom")] = Field(default=np.array(0)) """Formal charges for each atom, stored as a mapping from atom index to charge.""" bond_atoms: Annotated[NDArray[np.uint32], Shape("bond", 2)] = Field(default=np.array(0)) """Array of shape (n_bonds, 2) with atom indices for each bond.""" bond_order: Annotated[NDArray[np.uint8], Shape("bond")] = Field(default=np.array(1)) """Bond orders (1=single, 2=double, 3=triple), stored as int array.""" atom_resolved: Annotated[NDArray[np.bool], Shape("atom")] = Field(default_factory=_default_atom_resolved) """Optional boolean array of shape (n_atoms,) indicating which atoms have valid coordinates.""" atom_coordinates: Annotated[NDArray[np.float32], Shape("atom", 3)] = Field(default=np.array(0.0)) """Optional array of shape (n_atoms, 3) with xyz coordinates for each atom.""" model_config = {"frozen": True, "extra": "forbid", "validate_default": True} @property def num_atoms(self) -> int: """Total number of atoms in the molecule.""" return len(self.atom_name) @property def num_bonds(self) -> int: """Total number of bonds in the molecule.""" return len(self.bond_atoms) @cached_property def bond_length(self) -> npt.NDArray[np.float32]: """Calculate bond lengths if coordinates are available.""" diffs = self.atom_coordinates[self.bond_atoms[:, 0]] - self.atom_coordinates[self.bond_atoms[:, 1]] return np.linalg.norm(diffs, axis=1) # type: ignore[no-any-return] @cached_property def bond_resolved(self) -> npt.NDArray[np.bool]: return self.atom_resolved[self.bond_atoms].all(axis=-1) # type: ignore[return-value]
[docs] def to_dict(self) -> dict[str, Any]: return { key: value.tolist() if isinstance(value, np.ndarray) else value for key, value in self.model_dump().items() }
@cached_property def atom_adjacency_matrix(self) -> npt.NDArray[np.bool]: """Get the atom-atom adjacency matrix based on the bonds.""" adjacency_matrix = np.zeros((self.num_atoms, self.num_atoms), dtype=np.bool_) for atom1, atom2 in self.bond_atoms: adjacency_matrix[atom1, atom2] = 1 adjacency_matrix[atom2, atom1] = 1 return adjacency_matrix @cached_property def atom_adjacency_list(self) -> list[list[int]]: """Get the atom-atom adjacency list based on the bonds.""" adjacency_list: list[list[int]] = [[] for _ in range(self.num_atoms)] for atom1, atom2 in self.bond_atoms: adjacency_list[atom1].append(atom2) adjacency_list[atom2].append(atom1) return adjacency_list @cached_property def angle_atoms(self) -> NDArray[np.uint16]: """Get array of shape (n_angles, 3) with atom indices for each angle.""" angles = [] for atom2 in range(self.num_atoms): neighbors = self.atom_adjacency_list[atom2] for i in range(len(neighbors)): for j in range(i + 1, len(neighbors)): atom1 = neighbors[i] atom3 = neighbors[j] if atom1 < atom3: angles.append((atom1, atom2, atom3)) else: angles.append((atom3, atom2, atom1)) return np.array(angles, dtype=np.uint16) @cached_property def num_angles(self) -> int: return len(self.angle_atoms) @cached_property def angle_resolved(self) -> NDArray[np.bool]: """Get boolean array of shape (n_angles,) indicating which angles have valid coordinates.""" return self.atom_resolved[self.angle_atoms].all(axis=-1) # type: ignore[return-value] @cached_property def rings(self) -> list[list[int]]: from rdkit import Chem # noqa: PLC0415 from boltz_data.mol._to import rdmol_from_bzmol # noqa: PLC0415 rdmol = rdmol_from_bzmol(self) def sort_cycle(cycle: list[int]) -> list[int]: index = cycle.index(min(cycle)) cycle = cycle[index:] + cycle[:index] if cycle[1] > cycle[-1]: cycle = [cycle[0], *cycle[-1:0:-1]] return cycle return [sort_cycle(list(i)) for i in Chem.rdmolops.GetSSSR(rdmol)] @cached_property def aromatic_rings(self) -> list[list[int]]: from rdkit import Chem # 
noqa: PLC0415 # Huckel rule: (4n + 2) π-electrons aromatic_rings = [] for ring in self.rings: if len(ring) > 8: continue atom_pi_electrons = self.atom_num_pi_electrons[ring] if not (atom_pi_electrons > 0).all(): continue for i, atom_idx in enumerate(ring): has_exocyclic_o_or_n = False atom = self._rdmol.GetAtomWithIdx(atom_idx) for bond in atom.GetBonds(): if bond.GetBondType() == Chem.BondType.DOUBLE and bond.GetOtherAtom(atom).GetAtomicNum() in [7, 8]: has_exocyclic_o_or_n = True if has_exocyclic_o_or_n: atom_pi_electrons[i] -= 1 if (int(atom_pi_electrons.sum()) - 2) % 4 == 0: aromatic_rings.append(ring) return aromatic_rings @cached_property def atom_is_aromatic(self) -> NDArray[np.bool]: aromatic_atoms = np.zeros(self.num_atoms, dtype=bool) for ring in self.aromatic_rings: aromatic_atoms[ring] = True return aromatic_atoms @cached_property def bond_is_aromatic(self) -> NDArray[np.bool]: return self.atom_is_aromatic[self.bond_atoms].all(axis=-1) # type: ignore[return-value] @cached_property def _rdmol(self) -> Chem.RWMol: from boltz_data.mol._to import rdmol_from_bzmol # noqa: PLC0415 return rdmol_from_bzmol(self) @cached_property def atom_num_pi_electrons(self) -> NDArray[np.uint8]: from rdkit import Chem # noqa: PLC0415 def _get_num_pi_electrons(atom: Chem.Atom, /) -> int: match atom.GetHybridization(): case Chem.HybridizationType.SP2: if atom.GetAtomicNum() in NUM_ELECTRONS: num_electrons = NUM_ELECTRONS[atom.GetAtomicNum()] num_nbrs = atom.GetTotalDegree() num_electrons -= num_nbrs + (3 - num_nbrs) * 2 return num_electrons return 0 case _: return 0 return np.array([_get_num_pi_electrons(atom) for atom in self._rdmol.GetAtoms()], dtype=np.uint8) # type: ignore[no-untyped-call]
# Valence electron count per supported element, keyed by atomic number.
# Read by ``BZMol.atom_num_pi_electrons``; sp2 atoms whose element is not
# listed here are assigned zero pi electrons.
NUM_ELECTRONS: dict[int, int] = {
    6: 4,  # C
    7: 5,  # N
    8: 6,  # O
    16: 6,  # S
}