Source code for boltz_data.mol._from._from_ccd

from typing import Literal

import numpy as np
from rdkit import Chem

from boltz_data.ccd import ChemicalComponent, get_builtin_chemical_component_dictionary
from boltz_data.definition import InternalBond
from boltz_data.mol._mol import BZBioMol, BZMol
from boltz_data.sequence import BACKBONE_ATOMS

PERIODIC_TABLE = Chem.GetPeriodicTable()


def polymer_bzmol_from_chemical_components(
    *,
    polymer_type: Literal["protein", "dna", "rna"],
    chemical_components: list[ChemicalComponent],
    chain_id: str,
    residue_numbers: list[int] | None = None,
    description: str | None = None,
    bonds: list[InternalBond] | None = None,
) -> BZBioMol:
    """Build a polymer chain, adding backbone links that ``bonds`` does not supply.

    For each pair of consecutive residues ``(i, i + 1)`` one order-1 bond is
    added from the ``BACKBONE_ATOMS[polymer_type]["previous"]`` atom of residue
    ``i`` to the ``BACKBONE_ATOMS[polymer_type]["next"]`` atom of residue
    ``i + 1``, provided that:

    * ``bonds`` contains no bond with ``residue_index_1 == i`` and
      ``residue_index_2 == i + 1`` (whichever atoms that bond joins), and
    * both residues contain the backbone atom they would contribute.

    Added protein links mark ``"OXT"`` as leaving from residue ``i``; added
    DNA/RNA links mark ``"OP3"`` as leaving from residue ``i + 1``.

    Args:
        polymer_type: One of ``"protein"``, ``"dna"`` or ``"rna"``.
        chemical_components: Residues of the chain, in sequence order.
        chain_id: Forwarded to ``bzmol_from_chemical_components``.
        residue_numbers: Forwarded to ``bzmol_from_chemical_components``.
        description: Forwarded to ``bzmol_from_chemical_components``.
        bonds: Explicit inter-residue bonds. The list is not mutated; the
            generated backbone links are appended to a copy.

    Returns:
        The ``BZBioMol`` produced by ``bzmol_from_chemical_components`` for the
        components and the combined bond list.

    Raises:
        ValueError: If ``polymer_type`` is not a supported polymer type.
    """
    # The two polymer families differ only in which side of the new link
    # gives up an atom, so select that keyword once and share everything else.
    if polymer_type == "protein":
        leaving_atom: dict[str, str] = {"leaving_atom_name_1": "OXT"}
    elif polymer_type in {"rna", "dna"}:
        leaving_atom = {"leaving_atom_name_2": "OP3"}
    else:
        msg = f"Unknown polymer type: {polymer_type}"
        raise ValueError(msg)

    supplied_bonds = bonds or []
    # One pass over the supplied bonds instead of rescanning them for every
    # residue pair.
    already_linked = {(b.residue_index_1, b.residue_index_2) for b in supplied_bonds}
    previous_atom = BACKBONE_ATOMS[polymer_type]["previous"]
    next_atom = BACKBONE_ATOMS[polymer_type]["next"]

    backbone_links = [
        InternalBond(
            residue_index_1=i,
            atom_name_1=previous_atom,
            residue_index_2=i + 1,
            atom_name_2=next_atom,
            bond_order=1,
            **leaving_atom,
        )
        for i in range(len(chemical_components) - 1)
        if (i, i + 1) not in already_linked
        and previous_atom in chemical_components[i].atoms
        and next_atom in chemical_components[i + 1].atoms
    ]

    return bzmol_from_chemical_components(
        chemical_components=chemical_components,
        chain_id=chain_id,
        bonds=supplied_bonds + backbone_links,
        residue_numbers=residue_numbers,
        description=description,
    )


def bzmol_from_chemical_component(
    chemical_component: ChemicalComponent | str,
) -> BZMol:
    """Convert one chemical component into a ``BZMol``.

    Args:
        chemical_component: The component to convert, or a string key that is
            looked up in the built-in chemical component dictionary.

    Returns:
        A ``BZMol`` with one entry per atom of the component (in the iteration
        order of ``chemical_component.atoms``) and one entry per bond whose
        two atoms are both present among those atoms.
    """
    if isinstance(chemical_component, str):
        chemical_component = get_builtin_chemical_component_dictionary()[chemical_component]

    names: list[str] = []
    elements: list[int] = []
    charges: list[int] = []
    index_by_name: dict[str, int] = {}

    for atom in chemical_component.atoms.values():
        names.append(atom.atom_id)
        index_by_name[atom.atom_id] = len(index_by_name)
        symbol = atom.element.capitalize()
        # "D" (deuterium) is stored with hydrogen's atomic number.
        elements.append(1 if symbol == "D" else PERIODIC_TABLE.GetAtomicNumber(symbol))
        charges.append(atom.charge)

    bond_pairs: list[tuple[int, int]] = []
    orders: list[int] = []
    for bond in chemical_component.bonds.values():
        first = index_by_name.get(bond.atom_id_1)
        second = index_by_name.get(bond.atom_id_2)
        # Bonds that mention an atom missing from the component are dropped.
        if first is None or second is None:
            continue
        bond_pairs.append((first, second))
        orders.append(bond.order)

    if bond_pairs:
        bond_atoms = np.array(bond_pairs, dtype=np.uint32)
    else:
        # Keep the (n_bonds, 2) shape even when there are no bonds.
        bond_atoms = np.zeros((0, 2), dtype=np.uint32)

    return BZMol(
        atom_name=np.array(names, dtype=object),
        atom_element=np.array(elements, dtype=np.uint8),
        atom_charge=np.array(charges, dtype=np.int8),
        bond_atoms=bond_atoms,
        bond_order=np.array(orders, dtype=np.uint8),
    )
def bzmol_from_chemical_components(  # noqa: C901
    *,
    chemical_components: list[ChemicalComponent],
    chain_id: str,
    bonds: list[InternalBond] | None = None,
    residue_numbers: list[int] | None = None,
    description: str | None = None,
) -> BZBioMol:
    """Assemble a single-chain ``BZBioMol`` from an ordered list of components.

    Atoms of every component are concatenated in order, skipping any atom
    named as a leaving atom for that residue by one of ``bonds``. Bonds inside
    a component are kept only when both of their atoms survived; the
    inter-residue ``bonds`` are appended after all intra-residue bonds.

    Args:
        chemical_components: Residues of the chain, in order. Residue indices
            used by ``bonds`` are positions in this list.
        chain_id: Identifier stored as the only entry of ``chain_id``.
        bonds: Inter-residue bonds to create. A bond whose ``bond_order`` is
            ``None`` is stored with order 1.
        residue_numbers: Optional residue numbers, stored as ``int32``.
        description: Optional chain description.

    Returns:
        The assembled ``BZBioMol``. Every residue is assigned chain index 0.

    Raises:
        ValueError: If an endpoint atom of one of ``bonds`` is not present in
            the assembled molecule (absent from the component, or removed as
            a leaving atom).
    """
    inter_residue_bonds = bonds or []

    # Gather leaving atoms per residue once, instead of rescanning every bond
    # for every residue.
    leaving_atoms_by_residue: dict[int, set[str]] = {}
    for bond in inter_residue_bonds:
        if bond.leaving_atom_name_1 is not None:
            leaving_atoms_by_residue.setdefault(bond.residue_index_1, set()).add(bond.leaving_atom_name_1)
        if bond.leaving_atom_name_2 is not None:
            leaving_atoms_by_residue.setdefault(bond.residue_index_2, set()).add(bond.leaving_atom_name_2)

    atom_names: list[str] = []
    atom_elements: list[int] = []
    atom_charges: list[int] = []
    atom_residues: list[int] = []
    residue_names: list[str] = []
    bond_atoms: list[tuple[int, int]] = []
    bond_orders: list[int] = []
    # (residue index, atom name) -> position of the atom in the flat arrays.
    atom_index: dict[tuple[int, str], int] = {}
    no_leaving_atoms: frozenset[str] = frozenset()

    for i, chemical_component in enumerate(chemical_components):
        leaving_atoms = leaving_atoms_by_residue.get(i, no_leaving_atoms)

        # NOTE(review): atoms bonded only to a leaving atom (e.g. its
        # hydrogens) are kept but lose that bond - confirm this is intended.
        num_atoms = 0
        for atom in chemical_component.atoms.values():
            if atom.atom_id in leaving_atoms:
                continue
            atom_index[(i, atom.atom_id)] = len(atom_names)
            atom_names.append(atom.atom_id)
            symbol = atom.element.capitalize()
            # "D" (deuterium) is stored with hydrogen's atomic number.
            atom_elements.append(1 if symbol == "D" else PERIODIC_TABLE.GetAtomicNumber(symbol))
            atom_charges.append(atom.charge)
            num_atoms += 1

        atom_residues.extend([i] * num_atoms)
        residue_names.append(chemical_component.comp_id)

        # Intra-residue bonds: keep only those whose two atoms are both present.
        for bond in chemical_component.bonds.values():
            key_1 = (i, bond.atom_id_1)
            key_2 = (i, bond.atom_id_2)
            if key_1 in atom_index and key_2 in atom_index:
                bond_atoms.append((atom_index[key_1], atom_index[key_2]))
                bond_orders.append(bond.order)

    # polymer bonds
    for bond in inter_residue_bonds:
        end_1 = (bond.residue_index_1, bond.atom_name_1)
        end_2 = (bond.residue_index_2, bond.atom_name_2)
        for residue_index, atom_name in (end_1, end_2):
            if (residue_index, atom_name) not in atom_index:
                msg = (
                    f"Cannot form polymer bond: Atom {atom_name} not found "
                    f"in {chemical_components[residue_index].comp_id}{residue_index + 1}"
                )
                raise ValueError(msg)
        bond_atoms.append((atom_index[end_1], atom_index[end_2]))
        bond_orders.append(bond.bond_order if bond.bond_order is not None else 1)

    return BZBioMol(
        atom_name=np.array(atom_names, dtype=object),
        atom_element=np.array(atom_elements, dtype=np.uint8),
        atom_charge=np.array(atom_charges, dtype=np.int8),
        atom_residue=np.array(atom_residues, dtype=np.uint32),
        residue_name=np.array(residue_names, dtype=object),
        residue_number=np.array(residue_numbers, dtype=np.int32) if residue_numbers is not None else None,
        residue_chain=np.zeros(len(residue_names), dtype=np.uint16),
        # Keep the (n_bonds, 2) shape even when there are no bonds.
        bond_atoms=np.array(bond_atoms, dtype=np.uint32) if bond_atoms else np.zeros((0, 2), dtype=np.uint32),
        bond_order=np.array(bond_orders, dtype=np.uint8),
        chain_id=np.array([chain_id], dtype=object) if chain_id is not None else None,
        chain_description=np.array([description], dtype=object) if description is not None else None,
    )