Source code for boltz_data.mol._op._concat
"""Functions for concatenating multiple BZBioMol objects."""
from typing import Any, cast, overload
import numpy as np
from boltz_data.mol._mol import BZBioMol, BZMol
@overload
def concat_bzmols(*bzmols: BZBioMol) -> BZBioMol:
    """Typing overload: when every input is a BZBioMol, the result is typed as a BZBioMol."""
@overload
def concat_bzmols(*bzmols: BZBioMol | BZMol) -> BZMol:
    """Typing overload: any other combination of inputs is typed as returning a plain BZMol."""
[docs]
def concat_bzmols(*bzmols: BZMol | BZBioMol) -> BZMol | BZBioMol:
    """
    Concatenate multiple BZMol / BZBioMol objects into a single molecule.

    The result is a BZBioMol only when every input is a BZBioMol. If any input is a plain
    BZMol, the result is a BZMol and residue/chain fields are not carried over.

    Index arrays (``bond_atoms``, ``atom_residue``, ``residue_chain``) are shifted by the
    number of atoms, residues and chains that precede each input. This keeps them pointing
    at the right rows of the concatenated arrays.

    Args:
        *bzmols: One or more BZMol or BZBioMol objects, concatenated in the given order.

    Returns:
        A single BZBioMol (if all inputs are BZBioMol) or BZMol (otherwise). It contains
        all atoms and bonds of the inputs and, for BZBioMol, their residues and chains.
        When exactly one molecule is given, it is returned as-is, not copied.

    Raises:
        ValueError: If no molecules are given, or if an optional field (``residue_number``,
            ``chain_description``, ``atom_coordinates``, ``atom_resolved``) is set on some
            inputs but ``None`` on others.
    """
    if len(bzmols) == 0:
        msg = "At least one BZMol must be provided"
        raise ValueError(msg)
    if len(bzmols) == 1:
        return bzmols[0]

    kwargs: dict[str, Any] = {}
    # Only keep residue/chain information when every input actually carries it.
    output_type = BZBioMol if all(isinstance(mol, BZBioMol) for mol in bzmols) else BZMol

    if output_type is BZBioMol:
        biomols = cast("tuple[BZBioMol, ...]", bzmols)
        kwargs["residue_name"] = np.concatenate([mol.residue_name for mol in biomols])
        kwargs["residue_number"] = _concat_optional([mol.residue_number for mol in biomols], "residue_number")
        kwargs["chain_id"] = np.concatenate([mol.chain_id for mol in biomols])
        # Missing descriptions must not be silently dropped. Dropping them would make this
        # array shorter than the other chain arrays (assuming one description per chain --
        # TODO confirm against BZBioMol).
        kwargs["chain_description"] = _concat_optional(
            [mol.chain_description for mol in biomols], "chain_description"
        )

        residue_starts = _start_offsets([mol.num_residues for mol in biomols])
        chain_starts = _start_offsets([mol.num_chains for mol in biomols])
        # Residue/chain indices are local to each input; shift them into the combined index space.
        kwargs["atom_residue"] = np.concatenate(
            [mol.atom_residue + start for start, mol in zip(residue_starts, biomols, strict=True)]
        )
        kwargs["residue_chain"] = np.concatenate(
            [mol.residue_chain + start for start, mol in zip(chain_starts, biomols, strict=True)]
        )

    kwargs["atom_name"] = np.concatenate([mol.atom_name for mol in bzmols])
    kwargs["atom_element"] = np.concatenate([mol.atom_element for mol in bzmols])
    kwargs["atom_charge"] = np.concatenate([mol.atom_charge for mol in bzmols])
    kwargs["bond_order"] = np.concatenate([mol.bond_order for mol in bzmols])

    atom_starts = _start_offsets([mol.num_atoms for mol in bzmols])
    # Bond endpoints are atom indices local to each input; shift them the same way.
    kwargs["bond_atoms"] = np.concatenate(
        [mol.bond_atoms + start for start, mol in zip(atom_starts, bzmols, strict=True)]
    )

    kwargs["atom_coordinates"] = _concat_optional([mol.atom_coordinates for mol in bzmols], "atom_coordinates")
    kwargs["atom_resolved"] = _concat_optional([mol.atom_resolved for mol in bzmols], "atom_resolved")
    return output_type(**kwargs)


def _start_offsets(counts: list[int]) -> np.ndarray:
    """Return, for each input, the index at which its rows start in the concatenated array."""
    # Exclusive prefix sum: [0, c0, c0 + c1, ...] with one entry per input.
    return np.pad(np.cumsum(counts, dtype=int)[:-1], (1, 0), "constant")


def _concat_optional(arrays: list[Any], field_name: str) -> np.ndarray | None:
    """
    Concatenate an optional per-molecule field.

    Returns None when no input sets the field and the concatenation when every input sets it.

    Raises:
        ValueError: If the field is set on some inputs but ``None`` on others. The entries
            would otherwise no longer line up with the other concatenated arrays.
    """
    present = [array is not None for array in arrays]
    if not any(present):
        return None
    if not all(present):
        msg = f"Cannot concatenate '{field_name}': it is set on some molecules but None on others"
        raise ValueError(msg)
    return np.concatenate(arrays)