# Source code for boltz_data.mol._mol._bzbiomol
"""BZBioMol class for biomolecular structures with residues and chains."""
from functools import cached_property
from typing import Annotated
import numpy as np
import numpy.typing as npt
from pydantic import Field
from boltz_data._utils import mask_to_slice
from boltz_data.pydantic import NDArray, Shape
from ._bzmol import BZMol
# [docs]
class BZBioMol(BZMol):
    """
    Biomolecular structure with residues and chains.

    Extends BZMol to include residue and chain information
    for proteins, DNA, RNA, and other biomolecules.

    The hierarchy is stored as index arrays: ``atom_residue`` maps each atom
    to a 0-based residue index, and ``residue_chain`` maps each residue to a
    0-based chain index.
    """

    atom_b_factor: Annotated[NDArray[np.float32], Shape("atom")] | None = Field(default=None)
    """Optional array of shape (n_atoms,) with B-factors for each atom."""

    atom_residue: Annotated[NDArray[np.uint32], Shape("atom")]
    """Array of shape (n_atoms,) with the 0-based index of the residue each atom belongs to."""

    residue_name: Annotated[NDArray[np.str_], Shape("residue")]
    """Residue names (e.g., 'ALA', 'GLY' for proteins, 'A', 'G' for nucleic acids)."""

    residue_number: Annotated[NDArray[np.int32], Shape("residue")] | None = Field(default=None)
    """Optional array of shape (n_residues,) with a number for each residue."""

    residue_chain: Annotated[NDArray[np.uint16], Shape("residue")]
    """Array of shape (n_residues,) with the 0-based index of the chain each residue belongs to."""

    chain_id: Annotated[NDArray[np.str_], Shape("chain")]
    """Chain identifiers (e.g., 'A', 'B', 'C'), one per chain."""

    chain_description: Annotated[NDArray[np.str_], Shape("chain")] | None = Field(default=None)
    """Optional description string for each chain."""

    @property
    def num_residues(self) -> int:
        """Total number of residues in the molecule (the length of ``residue_name``)."""
        return len(self.residue_name)

    @property
    def num_chains(self) -> int:
        """Total number of chains in the molecule (the length of ``chain_id``)."""
        return len(self.chain_id)

    @cached_property
    def residue_ordinal(self) -> npt.NDArray[np.uint16]:
        """
        0-indexed ordinal number of each residue within its chain.

        Residues are numbered in the order they appear in ``residue_chain``:
        the first residue seen for a chain gets 0, the next one 1, and so on,
        even if that chain's residues are not stored contiguously. The result
        has the same dtype and length as ``residue_chain``.
        """
        chains = self.residue_chain
        # Sizes of each chain group, in ascending chain-index order.
        _, group_sizes = np.unique(chains, return_counts=True)
        # A stable sort groups residues by chain (ascending, matching np.unique)
        # while keeping their original relative order inside each chain.
        order = np.argsort(chains, kind="stable")
        # Position of each chain's first residue within the sorted order.
        group_starts = np.cumsum(group_sizes) - group_sizes
        # Sorted position minus its group's start gives the within-chain ordinal.
        ordinal = np.empty_like(chains)
        ordinal[order] = np.arange(chains.size) - np.repeat(group_starts, group_sizes)
        return ordinal

    @cached_property
    def residue_atoms(self) -> list[slice | list[int]]:
        """
        Get atom indices for each residue as slices or lists.

        Entry ``i`` selects the atoms whose ``atom_residue`` equals ``i``. It is
        a slice when those atoms are contiguous or when the residue has no
        atoms, and otherwise an explicit list of atom indices.
        """
        # NOTE(review): each entry scans every atom, so this is O(n_residues * n_atoms).
        return [
            _mask_to_indices_or_slice(self.atom_residue == residue_idx)
            for residue_idx in range(self.num_residues)
        ]

    @cached_property
    def chain_residues(self) -> list[slice | list[int]]:
        """
        Get residue indices for each chain as slices or lists.

        Entry ``i`` selects the residues whose ``residue_chain`` equals ``i``,
        using the same slice-or-list representation as ``residue_atoms``.
        """
        return [
            _mask_to_indices_or_slice(self.residue_chain == chain_idx)
            for chain_idx in range(self.num_chains)
        ]

    @cached_property
    def chain_atoms(self) -> list[slice | list[int]]:
        """
        Get atom indices for each chain as slices or lists.

        Entry ``i`` selects every atom whose residue belongs to chain ``i``,
        using the same slice-or-list representation as ``residue_atoms``.
        """
        # Compose atom -> residue -> chain to get each atom's chain index.
        atom_chain = self.residue_chain[self.atom_residue]
        return [
            _mask_to_indices_or_slice(atom_chain == chain_idx)
            for chain_idx in range(self.num_chains)
        ]

    @cached_property
    def residue_any_resolved(self) -> npt.NDArray[np.bool]:
        """Get a boolean array of shape (n_residues,) indicating which residues have any resolved atoms."""
        # Assumes ``atom_resolved`` (defined on BZMol) is a per-atom boolean mask — TODO confirm.
        # Count the selected atoms per residue. minlength keeps trailing residues
        # with no selected atoms, so the output length is always num_residues.
        resolved_per_residue = np.bincount(
            self.atom_residue[self.atom_resolved],
            minlength=self.num_residues,
        )
        return resolved_per_residue > 0
def _mask_to_indices_or_slice(mask: npt.NDArray[np.bool]) -> slice | list[int]:
if mask.any():
# Use slice if contiguous, otherwise use list
true_indices = np.where(mask)[0]
if len(true_indices) == 1 or np.all(np.diff(true_indices) == 1):
return mask_to_slice(mask)
return true_indices.tolist() # type: ignore[no-any-return]
return slice(0, 0, None) # Empty slice for residues with no atoms