Source code for boltz_data.mol._mol._bzbiomol

"""BZBioMol class for biomolecular structures with residues and chains."""

from functools import cached_property
from typing import Annotated

import numpy as np
import numpy.typing as npt
from pydantic import Field

from boltz_data._utils import mask_to_slice
from boltz_data.pydantic import NDArray, Shape

from ._bzmol import BZMol


class BZBioMol(BZMol):
    """
    Biomolecular structure with residues and chains.

    Extends BZMol to include residue and chain information for proteins, DNA, RNA, and other biomolecules.

    The hierarchy is stored as flat arrays linked by indices: ``atom_residue`` gives, for every atom,
    the index of its residue, and ``residue_chain`` gives, for every residue, the index of its chain.
    """

    atom_b_factor: Annotated[NDArray[np.float32], Shape("atom")] | None = Field(default=None)
    """Optional array of shape (n_atoms,) with B-factors for each atom."""

    # Index of the residue each atom belongs to (used below to index `residue_chain`).
    atom_residue: Annotated[NDArray[np.uint32], Shape("atom")]

    residue_name: Annotated[NDArray[np.str_], Shape("residue")]
    """Residue names (e.g., 'ALA', 'GLY' for proteins, 'A', 'G' for nucleic acids)."""

    # Optional per-residue number; presumably the numbering from the source structure — TODO confirm.
    residue_number: Annotated[NDArray[np.int32], Shape("residue")] | None = Field(default=None)

    # Index of the chain each residue belongs to (compared against range(num_chains) below).
    residue_chain: Annotated[NDArray[np.uint16], Shape("residue")]

    chain_id: Annotated[NDArray[np.str_], Shape("chain")]
    """Chain identifiers (e.g., 'A', 'B', 'C') for each chain."""

    # Optional per-chain description strings.
    chain_description: Annotated[NDArray[np.str_], Shape("chain")] | None = Field(default=None)

    @property
    def num_residues(self) -> int:
        """Total number of residues in the molecule."""
        return len(self.residue_name)

    @property
    def num_chains(self) -> int:
        """Total number of chains in the molecule."""
        return len(self.chain_id)

    @cached_property
    def residue_ordinal(self) -> npt.NDArray[np.uint16]:
        """
        0-indexed ordinal number of each residue within its chain.

        Residues of a chain are numbered 0, 1, 2, ... in the order they appear in the residue arrays.
        The result has the same shape and dtype as ``residue_chain``.
        """
        ordinals = np.zeros_like(self.residue_chain)
        for chain_index in np.unique(self.residue_chain):
            in_chain = self.residue_chain == chain_index
            # Boolean-mask assignment fills the selected positions in ascending index order.
            ordinals[in_chain] = np.arange(np.count_nonzero(in_chain))
        return ordinals

    @cached_property
    def residue_atoms(self) -> list[slice | list[int]]:
        """
        Get atom indices for each residue as slices or lists.

        Entry ``i`` selects the atoms whose ``atom_residue`` equals ``i``.
        """
        # NOTE: builds one full-length boolean mask per residue.
        return [
            _mask_to_indices_or_slice(self.atom_residue == residue_idx) for residue_idx in range(self.num_residues)
        ]

    @cached_property
    def chain_residues(self) -> list[slice | list[int]]:
        """
        Get residue indices for each chain as slices or lists.

        Entry ``i`` selects the residues whose ``residue_chain`` equals ``i``.
        """
        return [_mask_to_indices_or_slice(self.residue_chain == chain_idx) for chain_idx in range(self.num_chains)]

    @cached_property
    def chain_atoms(self) -> list[slice | list[int]]:
        """
        Get atom indices for each chain as slices or lists.

        Entry ``i`` selects the atoms whose residue belongs to chain ``i``.
        """
        # Map every atom to its chain by going atom -> residue -> chain.
        atom_chain = self.residue_chain[self.atom_residue]
        return [_mask_to_indices_or_slice(atom_chain == chain_idx) for chain_idx in range(self.num_chains)]

    @cached_property
    def residue_any_resolved(self) -> npt.NDArray[np.bool]:
        """Get a boolean array of shape (n_residues,) indicating which residues have any resolved atoms."""
        # `atom_resolved` is not defined in this class; presumably a per-atom boolean mask
        # inherited from BZMol — TODO confirm.
        resolved_atom_residue = self.atom_residue[self.atom_resolved]
        # minlength keeps trailing residues without resolved atoms in the output.
        resolved_per_residue = np.bincount(resolved_atom_residue, minlength=self.num_residues)
        return resolved_per_residue > 0
def _mask_to_indices_or_slice(mask: npt.NDArray[np.bool]) -> slice | list[int]:
    """
    Turn a boolean mask into a compact index object.

    Returns:
        ``slice(0, 0, None)`` when no element of ``mask`` is set; the result of
        ``mask_to_slice(mask)`` when the set positions are a single element or one
        consecutive run; otherwise a list of the set positions.
    """
    if not mask.any():
        return slice(0, 0, None)  # Empty slice for residues with no atoms

    selected = np.nonzero(mask)[0]
    # A lone index is trivially contiguous; otherwise any neighbour step other than 1 breaks the run.
    has_gap = selected.size != 1 and bool(np.any(np.diff(selected) != 1))
    if has_gap:
        return selected.tolist()  # type: ignore[no-any-return]
    # Use slice if contiguous
    return mask_to_slice(mask)