# Source code for boltz_data.mol._op._concat

"""Functions for concatenating multiple BZBioMol objects."""

from typing import Any, cast, overload

import numpy as np

from boltz_data.mol._mol import BZBioMol, BZMol


@overload
def concat_bzmols(*bzmols: BZBioMol) -> BZBioMol:
    """Typing overload: when every input is a BZBioMol, the result is a BZBioMol."""


@overload
def concat_bzmols(*bzmols: BZBioMol | BZMol) -> BZMol:
    """Typing overload: when the inputs mix BZBioMol and BZMol, the result is a BZMol."""


def concat_bzmols(*bzmols: BZMol | BZBioMol) -> BZMol | BZBioMol:
    """
    Concatenate multiple BZMol/BZBioMol objects into a single molecule.

    If every input is a BZBioMol, the result is a BZBioMol that also carries the
    concatenated residue and chain annotations. If any input is a plain BZMol, the
    result is a BZMol built from the atom- and bond-level fields only.

    The index arrays of each input are shifted by the number of residues, chains or
    atoms in the inputs before it, so they keep pointing at the same entries in the
    combined molecule. These arrays are ``atom_residue``, ``residue_chain`` and
    ``bond_atoms``.

    Args:
        *bzmols: One or more BZMol or BZBioMol objects, concatenated in the given order.

    Returns:
        A single molecule containing all atoms and bonds of the inputs, plus all
        residues and chains when the result is a BZBioMol. If exactly one molecule
        is passed, that same object is returned without copying.

    Raises:
        ValueError: If no molecules are passed. Also raised if ``residue_number``,
            ``atom_coordinates`` or ``atom_resolved`` is set on some inputs but
            ``None`` on others.
    """
    if len(bzmols) == 0:
        msg = "At least one BZMol must be provided"
        raise ValueError(msg)
    if len(bzmols) == 1:
        return bzmols[0]

    kwargs: dict[str, Any] = {}
    output_type = BZBioMol if all(isinstance(mol, BZBioMol) for mol in bzmols) else BZMol

    if output_type is BZBioMol:
        biomols = cast("tuple[BZBioMol, ...]", bzmols)
        kwargs["residue_name"] = np.concatenate([mol.residue_name for mol in biomols])
        kwargs["residue_number"] = _concat_optional("residue_number", [mol.residue_number for mol in biomols])
        kwargs["chain_id"] = np.concatenate([mol.chain_id for mol in biomols])
        kwargs["chain_description"] = _concat_chain_descriptions(biomols)

        # Shift residue/chain indices so they refer to rows of the concatenated arrays.
        residue_offsets = _exclusive_cumsum([mol.num_residues for mol in biomols])
        chain_offsets = _exclusive_cumsum([mol.num_chains for mol in biomols])
        kwargs["atom_residue"] = np.concatenate(
            [mol.atom_residue + offset for offset, mol in zip(residue_offsets, biomols, strict=True)]
        )
        kwargs["residue_chain"] = np.concatenate(
            [mol.residue_chain + offset for offset, mol in zip(chain_offsets, biomols, strict=True)]
        )

    kwargs["atom_name"] = np.concatenate([mol.atom_name for mol in bzmols])
    kwargs["atom_element"] = np.concatenate([mol.atom_element for mol in bzmols])
    kwargs["atom_charge"] = np.concatenate([mol.atom_charge for mol in bzmols])
    kwargs["bond_order"] = np.concatenate([mol.bond_order for mol in bzmols])

    # Bond endpoints are atom indices, so shift them by the atom count of preceding molecules.
    atom_offsets = _exclusive_cumsum([mol.num_atoms for mol in bzmols])
    kwargs["bond_atoms"] = np.concatenate(
        [mol.bond_atoms + offset for offset, mol in zip(atom_offsets, bzmols, strict=True)]
    )
    kwargs["atom_coordinates"] = _concat_optional("atom_coordinates", [mol.atom_coordinates for mol in bzmols])
    kwargs["atom_resolved"] = _concat_optional("atom_resolved", [mol.atom_resolved for mol in bzmols])

    return output_type(**kwargs)


def _exclusive_cumsum(counts: list[int]) -> np.ndarray:
    """Return ``[0, c0, c0 + c1, ...]``: the start offset of each item in a concatenation."""
    sizes = np.asarray(counts, dtype=int)
    # Inclusive prefix sum minus each item's own size gives the exclusive prefix sum.
    return np.cumsum(sizes) - sizes


def _concat_optional(field: str, arrays: list[np.ndarray | None]) -> np.ndarray | None:
    """
    Concatenate an optional field across molecules.

    Returns None if the field is None on every molecule. Raises ValueError if it is
    None on only some of them, because there is no data to fill the gap with.
    """
    present = [array is not None for array in arrays]
    if not any(present):
        return None
    if not all(present):
        msg = f"Cannot concatenate '{field}': it is set on some molecules but None on others"
        raise ValueError(msg)
    return np.concatenate(arrays)


def _concat_chain_descriptions(biomols: tuple[BZBioMol, ...]) -> np.ndarray | None:
    """Concatenate chain descriptions, padding inputs that have none so entries stay aligned with chains."""
    descriptions = [mol.chain_description for mol in biomols]
    reference = next((desc for desc in descriptions if desc is not None), None)
    if reference is None:
        return None
    # Inputs without descriptions contribute one placeholder per chain. Dropping them instead
    # would shift every later description onto the wrong chain.
    # NOTE(review): assumes chain_description holds one entry per chain (like chain_id) and that
    # an empty string is an acceptable "no description" placeholder. Confirm against BZBioMol.
    placeholder_dtype = np.asarray(reference).dtype
    return np.concatenate(
        [
            desc if desc is not None else np.full(mol.num_chains, "", dtype=placeholder_dtype)
            for desc, mol in zip(descriptions, biomols, strict=True)
        ]
    )