"""Base BZMol class for molecular structures."""
from functools import cached_property
from typing import Annotated, Any
import numpy as np
import numpy.typing as npt
from pydantic import BaseModel, Field
from rdkit import Chem
from boltz_data.pydantic import NDArray, Shape, ValidationContext
def _default_atom_resolved(arg: dict[str, Any]) -> NDArray[np.bool]:
    """
    Default value for ``BZMol.atom_resolved``.

    Returns a 0-d boolean array: ``True`` when ``arg`` has an ``atom_coordinates``
    key, ``False`` otherwise.

    NOTE(review): when pydantic calls this as a ``default_factory``, ``arg`` holds only
    the already-validated fields declared *before* ``atom_resolved``; confirm that
    ``atom_coordinates`` can actually be present at that point for the caller in question.
    """
    has_coordinates = "atom_coordinates" in arg
    return np.array(has_coordinates)
class BZMol(BaseModel):
    """
    Base representation of a molecular structure with atoms and bonds.

    This object can represent both single molecules and multiple disconnected molecules,
    in a similar manner to an RDKit `Mol` object.

    Unlike a `Mol` object, information is stored by property, rather than by atom. For example,
    the `atom_name` property is an array of shape `(n_atoms,)` containing the names of all atoms.
    """

    def __init__(self, /, **data: Any) -> None:
        """
        Validate ``data`` into this instance using a fresh ``ValidationContext``.

        When ``atom_resolved`` is not given, it defaults to ``True`` if ``atom_coordinates``
        was given and ``False`` otherwise. The decision is made here, from the raw keyword
        arguments, because pydantic only passes a ``default_factory`` the fields declared
        *before* it, and ``atom_coordinates`` is declared after ``atom_resolved``.
        """
        if "atom_resolved" not in data:
            data["atom_resolved"] = _default_atom_resolved(data)
        self.__pydantic_validator__.validate_python(data, self_instance=self, context=ValidationContext())

    atom_element: Annotated[NDArray[np.uint8], Shape("atom")]
    """Atomic numbers for each atom, stored as int array."""
    atom_name: Annotated[NDArray[np.str_], Shape("atom")] = Field(default=np.array(""))
    """The atom names (e.g., 'CA', 'N', 'C', 'O' for proteins)."""
    atom_charge: Annotated[NDArray[np.int8], Shape("atom")] = Field(default=np.array(0))
    """Formal charge of each atom, stored as an int8 array of shape (n_atoms,)."""
    bond_atoms: Annotated[NDArray[np.uint32], Shape("bond", 2)] = Field(default=np.array(0))
    """Array of shape (n_bonds, 2) with atom indices for each bond."""
    bond_order: Annotated[NDArray[np.uint8], Shape("bond")] = Field(default=np.array(1))
    """Bond orders (1=single, 2=double, 3=triple), stored as int array."""
    # The default_factory only sees fields declared above this one; __init__ supplies the
    # coordinate-aware default for normal construction.
    atom_resolved: Annotated[NDArray[np.bool], Shape("atom")] = Field(default_factory=_default_atom_resolved)
    """Optional boolean array of shape (n_atoms,) indicating which atoms have valid coordinates."""
    atom_coordinates: Annotated[NDArray[np.float32], Shape("atom", 3)] = Field(default=np.array(0.0))
    """Optional array of shape (n_atoms, 3) with xyz coordinates for each atom."""

    model_config = {"frozen": True, "extra": "forbid", "validate_default": True}

    @property
    def num_atoms(self) -> int:
        """Total number of atoms in the molecule."""
        return len(self.atom_name)

    @property
    def num_bonds(self) -> int:
        """Total number of bonds in the molecule."""
        return len(self.bond_atoms)

    @cached_property
    def bond_length(self) -> npt.NDArray[np.float32]:
        """
        Euclidean distance between the two atoms of each bond, shape ``(n_bonds,)``.

        Computed from ``atom_coordinates`` regardless of ``atom_resolved``; combine with
        ``bond_resolved`` to ignore bonds whose atoms lack valid coordinates.
        """
        diffs = self.atom_coordinates[self.bond_atoms[:, 0]] - self.atom_coordinates[self.bond_atoms[:, 1]]
        return np.linalg.norm(diffs, axis=1)  # type: ignore[no-any-return]

    @cached_property
    def bond_resolved(self) -> npt.NDArray[np.bool]:
        """Boolean array of shape ``(n_bonds,)``: True where both bonded atoms are resolved."""
        return self.atom_resolved[self.bond_atoms].all(axis=-1)  # type: ignore[return-value]

    def to_dict(self) -> dict[str, Any]:
        """Return ``model_dump()`` with every top-level numpy array converted via ``tolist()``."""
        return {
            key: value.tolist() if isinstance(value, np.ndarray) else value for key, value in self.model_dump().items()
        }

    @cached_property
    def atom_adjacency_matrix(self) -> npt.NDArray[np.bool]:
        """Symmetric ``(n_atoms, n_atoms)`` boolean matrix; True where two atoms share a bond."""
        adjacency_matrix = np.zeros((self.num_atoms, self.num_atoms), dtype=np.bool_)
        atom1 = self.bond_atoms[:, 0]
        atom2 = self.bond_atoms[:, 1]
        # Vectorized scatter of both bond directions.
        adjacency_matrix[atom1, atom2] = True
        adjacency_matrix[atom2, atom1] = True
        return adjacency_matrix

    @cached_property
    def atom_adjacency_list(self) -> list[list[int]]:
        """For each atom, the indices (as Python ints) of the atoms it is bonded to, in bond order."""
        adjacency_list: list[list[int]] = [[] for _ in range(self.num_atoms)]
        # tolist() yields Python ints, matching the declared element type.
        for atom1, atom2 in self.bond_atoms.tolist():
            adjacency_list[atom1].append(atom2)
            adjacency_list[atom2].append(atom1)
        return adjacency_list

    @cached_property
    def angle_atoms(self) -> NDArray[np.uint32]:
        """
        Array of shape ``(n_angles, 3)`` with atom indices for each bond angle.

        Each row is ``(end1, center, end2)`` with ``end1 <= end2``; one row is produced per
        pair of neighbours of every atom. The dtype matches ``bond_atoms`` (uint32) so
        indices of large structures are not truncated.
        """
        angles: list[tuple[int, int, int]] = []
        for center, neighbors in enumerate(self.atom_adjacency_list):
            for k, first in enumerate(neighbors):
                for second in neighbors[k + 1 :]:
                    low, high = (first, second) if first < second else (second, first)
                    angles.append((low, center, high))
        # reshape keeps the documented (n_angles, 3) layout even when there are no angles.
        return np.array(angles, dtype=np.uint32).reshape(-1, 3)

    @cached_property
    def num_angles(self) -> int:
        """Number of rows in ``angle_atoms``."""
        return len(self.angle_atoms)

    @cached_property
    def angle_resolved(self) -> NDArray[np.bool]:
        """Get boolean array of shape (n_angles,) indicating which angles have valid coordinates."""
        return self.atom_resolved[self.angle_atoms].all(axis=-1)  # type: ignore[return-value]

    @cached_property
    def rings(self) -> list[list[int]]:
        """
        Rings from RDKit's ``GetSSSR`` as lists of atom indices.

        Each ring is rotated to start at its smallest atom index and oriented so that its
        second atom index is smaller than its last.
        """
        from boltz_data.mol._to import rdmol_from_bzmol  # noqa: PLC0415

        rdmol = rdmol_from_bzmol(self)

        def sort_cycle(cycle: list[int]) -> list[int]:
            index = cycle.index(min(cycle))
            cycle = cycle[index:] + cycle[:index]
            if cycle[1] > cycle[-1]:
                # Reverse the traversal direction while keeping the smallest atom first.
                cycle = [cycle[0], *cycle[-1:0:-1]]
            return cycle

        return [sort_cycle(list(i)) for i in Chem.rdmolops.GetSSSR(rdmol)]

    @cached_property
    def aromatic_rings(self) -> list[list[int]]:
        """
        Rings from ``rings`` judged aromatic by a Huckel (4n + 2) pi-electron count.

        A ring qualifies when it has at most 8 atoms and every ring atom contributes at least
        one pi electron (see ``atom_num_pi_electrons``). Ring atoms carrying a double bond
        to an N or O outside the ring lose one electron. The ring is aromatic if the
        remaining total equals 4n + 2.
        """
        aromatic_rings = []
        for ring in self.rings:
            if len(ring) > 8:
                continue
            # Integer-array indexing returns a copy, so the per-ring adjustment below
            # does not modify the cached atom_num_pi_electrons array.
            atom_pi_electrons = self.atom_num_pi_electrons[ring]
            if not (atom_pi_electrons > 0).all():
                continue
            ring_atoms = set(ring)
            for i, atom_idx in enumerate(ring):
                atom = self._rdmol.GetAtomWithIdx(atom_idx)
                has_exocyclic_o_or_n = False
                for bond in atom.GetBonds():
                    other = bond.GetOtherAtom(atom)
                    if (
                        bond.GetBondType() == Chem.BondType.DOUBLE
                        and other.GetIdx() not in ring_atoms
                        and other.GetAtomicNum() in (7, 8)
                    ):
                        has_exocyclic_o_or_n = True
                        break
                if has_exocyclic_o_or_n:
                    atom_pi_electrons[i] -= 1
            if (int(atom_pi_electrons.sum()) - 2) % 4 == 0:
                aromatic_rings.append(ring)
        return aromatic_rings

    @cached_property
    def atom_is_aromatic(self) -> NDArray[np.bool]:
        """Boolean array of shape ``(n_atoms,)``: True for atoms in any ring of ``aromatic_rings``."""
        aromatic_atoms = np.zeros(self.num_atoms, dtype=bool)
        for ring in self.aromatic_rings:
            aromatic_atoms[ring] = True
        return aromatic_atoms

    @cached_property
    def bond_is_aromatic(self) -> NDArray[np.bool]:
        """
        Boolean array of shape ``(n_bonds,)``: True for bonds that are an edge of an aromatic ring.

        Both atoms being aromatic is not sufficient. For example, the bond joining the two
        rings of biphenyl connects two aromatic atoms but lies in no aromatic ring.
        """
        ring_edges: set[frozenset[int]] = set()
        for ring in self.aromatic_rings:
            # Consecutive ring atoms (wrapping around to the start) form the ring's bonds.
            ring_edges.update(frozenset(pair) for pair in zip(ring, ring[1:] + ring[:1]))
        return np.array(
            [frozenset((atom1, atom2)) in ring_edges for atom1, atom2 in self.bond_atoms.tolist()],
            dtype=bool,
        )

    @cached_property
    def _rdmol(self) -> Chem.RWMol:
        """RDKit molecule built by ``rdmol_from_bzmol``, cached and reused by the aromaticity helpers."""
        from boltz_data.mol._to import rdmol_from_bzmol  # noqa: PLC0415

        return rdmol_from_bzmol(self)

    @cached_property
    def atom_num_pi_electrons(self) -> NDArray[np.uint8]:
        """
        Estimated pi electrons contributed by each atom, shape ``(n_atoms,)``.

        Only SP2-hybridized atoms whose element appears in ``NUM_ELECTRONS`` contribute;
        all other atoms get 0. For a contributing atom, the count starts from its
        ``NUM_ELECTRONS`` entry and subtracts its formal charge. It then subtracts one
        electron per neighbour (total degree, including hydrogens) and two per remaining
        sp2 orbital (``3 - degree``). The result is clamped at 0.
        """

        def _get_num_pi_electrons(atom: Chem.Atom, /) -> int:
            if atom.GetHybridization() != Chem.HybridizationType.SP2:
                return 0
            num_electrons = NUM_ELECTRONS.get(atom.GetAtomicNum())
            if num_electrons is None:
                return 0
            # A positive formal charge removes an electron; a negative one adds one.
            num_electrons -= atom.GetFormalCharge()
            num_nbrs = atom.GetTotalDegree()
            num_electrons -= num_nbrs + (3 - num_nbrs) * 2
            # The result is stored as uint8, so never let it go negative.
            return max(num_electrons, 0)

        return np.array([_get_num_pi_electrons(atom) for atom in self._rdmol.GetAtoms()], dtype=np.uint8)  # type: ignore[no-untyped-call]
# Valence electron count keyed by atomic number, for the elements that
# BZMol.atom_num_pi_electrons can assign pi electrons to (C, N, O, S).
NUM_ELECTRONS = {
    6: 4,  # C
    7: 5,  # N
    8: 6,  # O
    16: 6,  # S
}