from collections.abc import Iterable, Mapping
from typing import Literal
import numpy as np
from boltz_data.ccd import ChemicalComponent, get_builtin_chemical_component_dictionary
from boltz_data.definition import (
BranchedPolymerDefinition,
ChainDefinition,
DNADefinition,
EntityDefinition,
LigandCCDDefinition,
ProteinDefinition,
RNADefinition,
StructureDefinition,
)
from boltz_data.mol._mol import BZBioMol
from boltz_data.sequence import sequence_from_residue_names
[docs]
def structure_from_bzmol(
    bzmol: BZBioMol, /, chemical_component_dictionary: Mapping[str, ChemicalComponent] | None = None
) -> StructureDefinition:
    """Build a structure definition describing every chain of a BZBioMol.

    Each chain is turned into an entity definition and registered under its chain id:

    * a chain with exactly one residue becomes a CCD ligand named after that residue;
    * any other chain becomes a protein, DNA or RNA sequence entity when
      ``infer_polymer_type_from_residue_names`` recognises it, and a branched polymer
      (with an empty bond list) otherwise.

    Chains whose entity definitions compare equal share a single entry in the returned
    ``entities`` list; the first occurrence determines the entity index.

    Args:
        bzmol: Molecule to convert.
        chemical_component_dictionary: Components used to infer polymer types. The
            built-in dictionary is used when this is omitted (or falsy).

    Returns:
        The structure definition, with ``bonds`` set to ``None``.
    """
    if not chemical_component_dictionary:
        chemical_component_dictionary = get_builtin_chemical_component_dictionary()

    # Entity classes that are described by a one-letter sequence, keyed by polymer type.
    sequence_entity_types = {"protein": ProteinDefinition, "dna": DNADefinition, "rna": RNADefinition}

    unique_entities: list[EntityDefinition] = []
    chain_definitions: dict[str, ChainDefinition] = {}

    chain_pairs = zip(
        bzmol.chain_id,  # type: ignore[arg-type]
        bzmol.chain_description if bzmol.chain_description is not None else [None] * bzmol.num_chains,
        strict=True,
    )
    for chain_idx, (chain_id, chain_description) in enumerate(chain_pairs):
        in_chain = bzmol.residue_chain == chain_idx
        chain_residue_names = bzmol.residue_name[in_chain]
        chain_residue_numbers = None if bzmol.residue_number is None else bzmol.residue_number[in_chain]

        entity: EntityDefinition
        if np.sum(in_chain) == 1:
            # A lone residue is described as a ligand rather than a one-residue polymer.
            entity = LigandCCDDefinition(
                type="ligand_ccd", comp_id=str(chain_residue_names[0]), description=chain_description
            )
        else:
            polymer_type = infer_polymer_type_from_residue_names(
                chain_residue_names, chemical_component_dictionary=chemical_component_dictionary
            )
            sequence_entity_type = sequence_entity_types.get(polymer_type)
            if sequence_entity_type is None:
                # Not recognised as protein/DNA/RNA: keep the raw component ids.
                entity = BranchedPolymerDefinition(
                    type="branched_polymer",
                    comp_ids=chain_residue_names.tolist(),
                    bonds=[],
                    description=chain_description,
                )
            else:
                sequence = sequence_from_residue_names(
                    chain_residue_names.tolist(), polymer_type=polymer_type, nonstandard_handling="parentheses"
                )
                entity = sequence_entity_type(type=polymer_type, sequence=sequence, description=chain_description)

        # Reuse an equal entity if one was already registered; otherwise append a new one.
        try:
            entity_idx = unique_entities.index(entity)
        except ValueError:
            unique_entities.append(entity)
            entity_idx = len(unique_entities) - 1

        chain_definitions[chain_id] = ChainDefinition(
            entity_idx=entity_idx,
            residue_numbers=None if chain_residue_numbers is None else chain_residue_numbers.tolist(),
        )

    return StructureDefinition(entities=unique_entities, chains=chain_definitions, bonds=None)
def infer_polymer_type_from_residue_names(
residue_names: Iterable[str], /, chemical_component_dictionary: Mapping[str, ChemicalComponent] | None = None
) -> Literal["protein", "rna", "dna"] | None:
chemical_component_dictionary = chemical_component_dictionary or get_builtin_chemical_component_dictionary()
num_protein = 0
num_dna = 0
num_rna = 0
for residue_name in residue_names:
component = chemical_component_dictionary[residue_name]
match component.type:
case "L-PEPTIDE LINKING":
num_protein += 1
case "DNA LINKING":
num_dna += 1
case "RNA LINKING":
num_rna += 1
if num_protein >= 0.5:
return "protein"
if num_dna >= 0.5:
return "dna"
if num_rna >= 0.5:
return "rna"
return None