# Source code for boltz_data.mol._op._subset

from typing import TypeVar

import numpy as np
import numpy.typing as npt

from boltz_data.mol._mol import BZBioMol

TDType = TypeVar("TDType", bound=np.generic)


def _remap_indices(mask: npt.NDArray[TDType]) -> npt.NDArray[np.int32]:
    """Build an old-index -> new-index lookup table for a boolean filter.

    Position ``i`` of the result holds the index that element ``i`` will have
    after the array is filtered with ``mask``.  Selected positions are numbered
    consecutively from 0 in their original order; unselected positions hold -1.

    NOTE(review): the table is created with NumPy's default integer dtype,
    whereas the return annotation says ``int32`` -- confirm which is intended.
    """
    # Start with every position marked as "dropped".
    new_index = np.full(mask.shape[0], -1)
    # Number the surviving positions 0..k-1, where k is the count of kept entries.
    new_index[mask] = np.arange(mask.sum())
    return new_index


def subset_bzmol(bzmol: BZBioMol, *, chain_ids: list[str] | None = None) -> BZBioMol:
    """Extract a subset of a BZBioMol by selecting specific chains.

    A residue is kept when its chain is kept, and an atom is kept when its
    residue is kept.  A bond is kept only when every atom it references is
    kept.  The index arrays of the result (``atom_residue``, ``residue_chain``
    and ``bond_atoms``) are renumbered so they point into the filtered arrays.

    Args:
        bzmol: The molecule to filter.
        chain_ids: Chain identifiers to keep, matched against
            ``bzmol.chain_id``.  Identifiers that do not occur in ``bzmol``
            select nothing.  ``None`` (the default) keeps every chain.

    Returns:
        A new ``BZBioMol`` holding only the selected chains, in their original
        relative order.  Optional fields that are ``None`` on the input stay
        ``None`` on the output.
    """
    if chain_ids is None:
        chain_mask = np.ones(bzmol.num_chains, dtype=bool)
    else:
        # Build the lookup set once so each membership test is O(1) instead of
        # rescanning the list for every chain.
        wanted = set(chain_ids)
        chain_mask = np.array(
            [chain_id in wanted for chain_id in bzmol.chain_id],
            dtype=bool,
        )

    # Propagate the selection downwards: chain -> residue -> atom.
    residue_mask = chain_mask[bzmol.residue_chain]
    atom_mask = residue_mask[bzmol.atom_residue]

    # Old-index -> new-index tables used to renumber cross-references.
    atom_remapped = _remap_indices(atom_mask)
    residue_remapped = _remap_indices(residue_mask)
    chain_remapped = _remap_indices(chain_mask)

    # A bond survives only if all of its atoms survive.  Computed once and
    # shared by bond_atoms and bond_order so the two arrays stay aligned.
    bond_mask = np.all(atom_mask[bzmol.bond_atoms], axis=1)

    return BZBioMol(
        atom_name=bzmol.atom_name[atom_mask],
        atom_element=bzmol.atom_element[atom_mask],
        atom_charge=bzmol.atom_charge[atom_mask],
        atom_residue=residue_remapped[bzmol.atom_residue[atom_mask]],
        atom_coordinates=bzmol.atom_coordinates[atom_mask] if bzmol.atom_coordinates is not None else None,
        atom_resolved=bzmol.atom_resolved[atom_mask] if bzmol.atom_resolved is not None else None,
        residue_name=bzmol.residue_name[residue_mask],
        residue_number=bzmol.residue_number[residue_mask] if bzmol.residue_number is not None else None,
        residue_chain=chain_remapped[bzmol.residue_chain[residue_mask]],
        bond_atoms=atom_remapped[bzmol.bond_atoms[bond_mask]],
        bond_order=bzmol.bond_order[bond_mask],
        chain_id=bzmol.chain_id[chain_mask],
        chain_description=bzmol.chain_description[chain_mask] if bzmol.chain_description is not None else None,
    )