Source code for boltz_data.mol._to._to_definition

from collections.abc import Iterable, Mapping
from typing import Literal

import numpy as np

from boltz_data.ccd import ChemicalComponent, get_builtin_chemical_component_dictionary
from boltz_data.definition import (
    BranchedPolymerDefinition,
    ChainDefinition,
    DNADefinition,
    EntityDefinition,
    LigandCCDDefinition,
    ProteinDefinition,
    RNADefinition,
    StructureDefinition,
)
from boltz_data.mol._mol import BZBioMol
from boltz_data.sequence import sequence_from_residue_names


def structure_from_bzmol(
    bzmol: BZBioMol,
    /,
    chemical_component_dictionary: Mapping[str, ChemicalComponent] | None = None,
) -> StructureDefinition:
    """Convert a BZBioMol to a structure definition.

    Every chain of ``bzmol`` becomes one ``ChainDefinition`` keyed by its chain
    ID. Chains whose entity definitions compare equal share a single entry in
    the returned ``entities`` list; the chain records that entry's index.

    How a chain's entity is chosen:

    * exactly one residue -> ``LigandCCDDefinition`` using that residue name;
    * otherwise the polymer type is inferred from the residue names.
      "protein"/"dna"/"rna" yield the matching sequence definition, built with
      ``nonstandard_handling="parentheses"``.
    * any other inferred type -> ``BranchedPolymerDefinition`` listing the
      residue names, with an empty bond list.

    Args:
        bzmol: Molecule to convert.
        chemical_component_dictionary: Component lookup passed to polymer-type
            inference. The built-in dictionary is used when this is ``None``
            (or otherwise falsy).

    Returns:
        A ``StructureDefinition`` with the deduplicated entities, the chains,
        and ``bonds=None``.
    """
    ccd = chemical_component_dictionary or get_builtin_chemical_component_dictionary()

    entities: list[EntityDefinition] = []
    chains: dict[str, ChainDefinition] = {}

    # Without per-chain descriptions, pair every chain with None.
    descriptions = (
        bzmol.chain_description
        if bzmol.chain_description is not None
        else [None] * bzmol.num_chains
    )

    for chain_idx, (chain_id, description) in enumerate(
        zip(
            bzmol.chain_id,  # type: ignore[arg-type]
            descriptions,
            strict=True,
        )
    ):
        in_chain = bzmol.residue_chain == chain_idx
        names = bzmol.residue_name[in_chain]
        numbers = None if bzmol.residue_number is None else bzmol.residue_number[in_chain]

        entity: EntityDefinition
        if np.sum(in_chain) == 1:
            # A lone residue is modelled as a CCD ligand rather than a polymer.
            entity = LigandCCDDefinition(
                type="ligand_ccd",
                comp_id=str(names[0]),
                description=description,
            )
        else:
            polymer_type = infer_polymer_type_from_residue_names(
                names, chemical_component_dictionary=ccd
            )
            if polymer_type in ("protein", "dna", "rna"):
                sequence = sequence_from_residue_names(
                    names.tolist(),
                    polymer_type=polymer_type,
                    nonstandard_handling="parentheses",
                )
                if polymer_type == "protein":
                    entity = ProteinDefinition(type="protein", sequence=sequence, description=description)
                elif polymer_type == "dna":
                    entity = DNADefinition(type="dna", sequence=sequence, description=description)
                else:
                    entity = RNADefinition(type="rna", sequence=sequence, description=description)
            else:
                # Not recognised as a linear polymer: keep the raw component IDs.
                entity = BranchedPolymerDefinition(
                    type="branched_polymer",
                    comp_ids=names.tolist(),
                    bonds=[],
                    description=description,
                )

        # Reuse an equal, already-registered entity instead of duplicating it.
        if entity in entities:
            entity_idx = entities.index(entity)
        else:
            entity_idx = len(entities)
            entities.append(entity)

        # NOTE(review): a repeated chain ID would overwrite the earlier chain here.
        chains[chain_id] = ChainDefinition(
            entity_idx=entity_idx,
            residue_numbers=None if numbers is None else numbers.tolist(),
        )

    return StructureDefinition(entities=entities, chains=chains, bonds=None)
def infer_polymer_type_from_residue_names( residue_names: Iterable[str], /, chemical_component_dictionary: Mapping[str, ChemicalComponent] | None = None ) -> Literal["protein", "rna", "dna"] | None: chemical_component_dictionary = chemical_component_dictionary or get_builtin_chemical_component_dictionary() num_protein = 0 num_dna = 0 num_rna = 0 for residue_name in residue_names: component = chemical_component_dictionary[residue_name] match component.type: case "L-PEPTIDE LINKING": num_protein += 1 case "DNA LINKING": num_dna += 1 case "RNA LINKING": num_rna += 1 if num_protein >= 0.5: return "protein" if num_dna >= 0.5: return "dna" if num_rna >= 0.5: return "rna" return None