# Source code for boltz_data.mol._op._subset
from typing import TypeVar
import numpy as np
import numpy.typing as npt
from boltz_data.mol._mol import BZBioMol
TDType = TypeVar("TDType", bound=np.generic)
def _remap_indices(mask: npt.NDArray[TDType]) -> npt.NDArray[np.int32]:
remap = np.repeat(-1, mask.shape[0])
remap[mask] = np.arange(np.sum(mask))
return remap
# [docs]
def subset_bzmol(bzmol: BZBioMol, *, chain_ids: list[str] | None = None) -> BZBioMol:
    """Extract a subset of a BZBioMol by selecting specific chains.

    Args:
        bzmol: The molecule to take the subset from.
        chain_ids: IDs of the chains to keep, matched against ``bzmol.chain_id``.
            ``None`` (the default) keeps every chain. IDs that do not occur in
            ``bzmol.chain_id`` are ignored.

    Returns:
        A new ``BZBioMol`` holding only the selected chains together with their
        residues and atoms, in their original relative order. The
        ``atom_residue``, ``residue_chain`` and ``bond_atoms`` indices are
        renumbered so that they refer to positions within the subset. A bond is
        kept only when every atom in its ``bond_atoms`` row is kept.
    """
    if chain_ids is None:
        chain_mask = np.ones(bzmol.num_chains, dtype=bool)
    else:
        chain_mask = np.array([chain_id in chain_ids for chain_id in bzmol.chain_id], dtype=bool)

    # Propagate the selection downwards: a residue is kept when its chain is
    # kept, and an atom is kept when its residue is kept.
    residue_mask = chain_mask[bzmol.residue_chain]
    atom_mask = residue_mask[bzmol.atom_residue]

    # Lookup tables translating original indices into indices within the subset.
    atom_remapped = _remap_indices(atom_mask)
    residue_remapped = _remap_indices(residue_mask)
    chain_remapped = _remap_indices(chain_mask)

    # Computed once and shared by bond_atoms and bond_order so that both arrays
    # are filtered by exactly the same selection.
    bond_mask = np.all(atom_mask[bzmol.bond_atoms], axis=1)

    return BZBioMol(
        atom_name=bzmol.atom_name[atom_mask],
        atom_element=bzmol.atom_element[atom_mask],
        atom_charge=bzmol.atom_charge[atom_mask],
        atom_residue=residue_remapped[bzmol.atom_residue[atom_mask]],
        atom_coordinates=bzmol.atom_coordinates[atom_mask] if bzmol.atom_coordinates is not None else None,
        atom_resolved=bzmol.atom_resolved[atom_mask] if bzmol.atom_resolved is not None else None,
        residue_name=bzmol.residue_name[residue_mask],
        residue_number=bzmol.residue_number[residue_mask] if bzmol.residue_number is not None else None,
        residue_chain=chain_remapped[bzmol.residue_chain[residue_mask]],
        bond_atoms=atom_remapped[bzmol.bond_atoms[bond_mask]],
        bond_order=bzmol.bond_order[bond_mask],
        chain_id=bzmol.chain_id[chain_mask],
        chain_description=bzmol.chain_description[chain_mask] if bzmol.chain_description is not None else None,
    )