from typing import Literal
import numpy as np
from rdkit import Chem
from boltz_data.ccd import ChemicalComponent, get_builtin_chemical_component_dictionary
from boltz_data.definition import InternalBond
from boltz_data.mol._mol import BZBioMol, BZMol
from boltz_data.sequence import BACKBONE_ATOMS
# Module-level RDKit periodic table, created once at import time; the builders
# in this module use it to turn element symbols into atomic numbers.
PERIODIC_TABLE = Chem.GetPeriodicTable()
def polymer_bzmol_from_chemical_components(
    *,
    polymer_type: Literal["protein", "dna", "rna"],
    chemical_components: list[ChemicalComponent],
    chain_id: str,
    residue_numbers: list[int] | None = None,
    description: str | None = None,
    bonds: list[InternalBond] | None = None,
) -> BZBioMol:
    """Assemble a polymer chain, filling in backbone bonds that were not supplied.

    For each pair of consecutive residues ``(i, i + 1)`` that has no entry in
    ``bonds`` running from residue ``i`` to residue ``i + 1``, a bond of order 1
    is added between the ``"previous"`` backbone atom of residue ``i`` and the
    ``"next"`` backbone atom of residue ``i + 1`` (atom names come from
    ``BACKBONE_ATOMS[polymer_type]``). A pair is skipped when either of those
    atoms is absent from its component.

    Added protein bonds name ``OXT`` as the leaving atom of the first residue;
    added DNA/RNA bonds name ``OP3`` as the leaving atom of the second residue.

    Args:
        polymer_type: ``"protein"``, ``"dna"`` or ``"rna"``.
        chemical_components: Residues of the chain, in order.
        chain_id: Forwarded to ``bzmol_from_chemical_components``.
        residue_numbers: Forwarded to ``bzmol_from_chemical_components``.
        description: Forwarded to ``bzmol_from_chemical_components``.
        bonds: Caller-supplied inter-residue bonds; they are kept and placed
            before the automatically added backbone bonds.

    Returns:
        The value returned by ``bzmol_from_chemical_components`` for the
        residues and the combined bond list.

    Raises:
        ValueError: If ``polymer_type`` is not one of the supported values.
    """
    # Only the leaving atom (and which end of the bond it sits on) differs
    # between the polymer families; everything else is shared below.
    if polymer_type == "protein":
        leaving_atom_kwargs = {"leaving_atom_name_1": "OXT"}
    elif polymer_type in {"rna", "dna"}:
        leaving_atom_kwargs = {"leaving_atom_name_2": "OP3"}
    else:
        msg = f"Unknown polymer type: {polymer_type}"
        raise ValueError(msg)

    supplied_bonds = bonds or []
    backbone = BACKBONE_ATOMS[polymer_type]
    previous_atom = backbone["previous"]
    next_atom = backbone["next"]

    # Residue pairs that already carry a bond in the i -> i + 1 direction.
    linked_pairs = {(bond.residue_index_1, bond.residue_index_2) for bond in supplied_bonds}

    unlinked = [
        index
        for index in range(len(chemical_components) - 1)
        if (index, index + 1) not in linked_pairs
        and previous_atom in chemical_components[index].atoms
        and next_atom in chemical_components[index + 1].atoms
    ]

    backbone_bonds = [
        InternalBond(
            residue_index_1=index,
            atom_name_1=previous_atom,
            residue_index_2=index + 1,
            atom_name_2=next_atom,
            bond_order=1,
            **leaving_atom_kwargs,
        )
        for index in unlinked
    ]

    return bzmol_from_chemical_components(
        chemical_components=chemical_components,
        chain_id=chain_id,
        bonds=supplied_bonds + backbone_bonds,
        residue_numbers=residue_numbers,
        description=description,
    )
# [docs] -- link label left over from the rendered documentation page; not code.
def bzmol_from_chemical_component(
    chemical_component: ChemicalComponent | str,
) -> BZMol:
    """Convert a single chemical component into a ``BZMol``.

    Args:
        chemical_component: The component to convert, or a string key that is
            looked up in the built-in chemical component dictionary.

    Returns:
        A ``BZMol`` holding, per atom, its name, atomic number and charge, and,
        per bond, the pair of atom indices and the bond order. Atom indices
        follow the iteration order of the component's atoms. Bonds naming an
        atom that is not among the component's atoms are left out.
    """
    component = chemical_component
    if isinstance(component, str):
        component = get_builtin_chemical_component_dictionary()[component]

    names: list[str] = []
    atomic_numbers: list[int] = []
    charges: list[int] = []
    index_of: dict[str, int] = {}

    for atom in component.atoms.values():
        names.append(atom.atom_id)
        index_of[atom.atom_id] = len(index_of)
        symbol = atom.element.capitalize()
        # "D" (deuterium) is stored with hydrogen's atomic number instead of
        # being passed to the periodic table lookup.
        atomic_numbers.append(1 if symbol == "D" else PERIODIC_TABLE.GetAtomicNumber(symbol))
        charges.append(atom.charge)

    endpoints: list[tuple[int, int]] = []
    orders: list[int] = []
    for bond in component.bonds.values():
        # Skip bonds whose endpoints are not both known atoms.
        if bond.atom_id_1 not in index_of or bond.atom_id_2 not in index_of:
            continue
        endpoints.append((index_of[bond.atom_id_1], index_of[bond.atom_id_2]))
        orders.append(bond.order)

    # An empty list would become a 1-D array; keep the (n_bonds, 2) shape.
    if endpoints:
        bond_atom_array = np.array(endpoints, dtype=np.uint32)
    else:
        bond_atom_array = np.zeros((0, 2), dtype=np.uint32)

    return BZMol(
        atom_name=np.array(names, dtype=object),
        atom_element=np.array(atomic_numbers, dtype=np.uint8),
        atom_charge=np.array(charges, dtype=np.int8),
        bond_atoms=bond_atom_array,
        bond_order=np.array(orders, dtype=np.uint8),
    )
# [docs] -- link label left over from the rendered documentation page; not code.
def _leaving_atoms_by_residue(bonds: list[InternalBond]) -> dict[int, set[str]]:
    """Group the leaving-atom names declared on ``bonds`` by residue index."""
    leaving: dict[int, set[str]] = {}
    for bond in bonds:
        if bond.leaving_atom_name_1 is not None:
            leaving.setdefault(bond.residue_index_1, set()).add(bond.leaving_atom_name_1)
        if bond.leaving_atom_name_2 is not None:
            leaving.setdefault(bond.residue_index_2, set()).add(bond.leaving_atom_name_2)
    return leaving


def bzmol_from_chemical_components(  # noqa: C901
    *,
    chemical_components: list[ChemicalComponent],
    chain_id: str,
    bonds: list[InternalBond] | None = None,
    residue_numbers: list[int] | None = None,
    description: str | None = None,
) -> BZBioMol:
    """Build a single-chain ``BZBioMol`` from an ordered list of residues.

    Atoms are emitted residue by residue in the iteration order of each
    component's atoms, except atoms that some entry of ``bonds`` names as a
    leaving atom of that residue. Each component's own bonds are kept when both
    of their atoms were emitted. The entries of ``bonds`` are then appended as
    additional bonds between the named atoms.

    Args:
        chemical_components: Residues of the chain, in order.
        chain_id: Identifier stored as the only entry of ``chain_id``.
        bonds: Bonds between residues, addressed by residue index and atom
            name. A missing ``bond_order`` is stored as 1.
        residue_numbers: Optional residue numbers, stored as ``int32``.
        description: Optional text stored as the chain description.

    Returns:
        A ``BZBioMol`` in which every residue is assigned to chain index 0.

    Raises:
        ValueError: If an entry of ``bonds`` names an atom that was not emitted
            for its residue (absent from the component, or removed as a
            leaving atom).
    """
    inter_residue_bonds = list(bonds or [])
    # Grouped once up front so the bond list is not rescanned for every residue.
    leaving_by_residue = _leaving_atoms_by_residue(inter_residue_bonds)

    atom_names: list[str] = []
    atom_elements: list[int] = []
    atom_charges: list[int] = []
    atom_residues: list[int] = []
    residue_names: list[str] = []
    bond_atoms: list[tuple[int, int]] = []
    bond_orders: list[int] = []
    # (residue index, atom name) -> position of that atom in the output arrays.
    atom_index: dict[tuple[int, str], int] = {}

    for residue_index, component in enumerate(chemical_components):
        leaving_atoms = leaving_by_residue.get(residue_index, ())

        for atom in component.atoms.values():
            if atom.atom_id in leaving_atoms:
                continue
            atom_index[(residue_index, atom.atom_id)] = len(atom_names)
            atom_names.append(atom.atom_id)
            symbol = atom.element.capitalize()
            # "D" (deuterium) is stored with hydrogen's atomic number instead
            # of being passed to the periodic table lookup.
            atom_elements.append(1 if symbol == "D" else PERIODIC_TABLE.GetAtomicNumber(symbol))
            atom_charges.append(atom.charge)
            atom_residues.append(residue_index)
        residue_names.append(component.comp_id)

        # Intra-residue bonds: drop any that touch an atom that was not emitted.
        for bond in component.bonds.values():
            first = atom_index.get((residue_index, bond.atom_id_1))
            second = atom_index.get((residue_index, bond.atom_id_2))
            if first is None or second is None:
                continue
            bond_atoms.append((first, second))
            bond_orders.append(bond.order)

    # Inter-residue (polymer) bonds.
    for bond in inter_residue_bonds:
        endpoints: list[int] = []
        for bond_residue, atom_name in (
            (bond.residue_index_1, bond.atom_name_1),
            (bond.residue_index_2, bond.atom_name_2),
        ):
            if (bond_residue, atom_name) not in atom_index:
                msg = (
                    f"Cannot form polymer bond: Atom {atom_name} not found "
                    f"in {chemical_components[bond_residue].comp_id}{bond_residue + 1}"
                )
                raise ValueError(msg)
            endpoints.append(atom_index[(bond_residue, atom_name)])
        bond_atoms.append((endpoints[0], endpoints[1]))
        bond_orders.append(bond.bond_order if bond.bond_order is not None else 1)

    return BZBioMol(
        atom_name=np.array(atom_names, dtype=object),
        atom_element=np.array(atom_elements, dtype=np.uint8),
        atom_charge=np.array(atom_charges, dtype=np.int8),
        atom_residue=np.array(atom_residues, dtype=np.uint32),
        residue_name=np.array(residue_names, dtype=object),
        residue_number=np.array(residue_numbers, dtype=np.int32) if residue_numbers is not None else None,
        residue_chain=np.zeros(len(residue_names), dtype=np.uint16),
        # An empty list would become a 1-D array; keep the (n_bonds, 2) shape.
        bond_atoms=np.array(bond_atoms, dtype=np.uint32) if bond_atoms else np.zeros((0, 2), dtype=np.uint32),
        bond_order=np.array(bond_orders, dtype=np.uint8),
        chain_id=np.array([chain_id], dtype=object) if chain_id is not None else None,
        chain_description=np.array([description], dtype=object) if description is not None else None,
    )