from collections import defaultdict
from collections.abc import Mapping
import gemmi
from boltz_data.definition import (
BranchedPolymerDefinition,
ChainDefinition,
DNADefinition,
EntityDefinition,
InterChainBond,
InternalBond,
LigandCCDDefinition,
ProteinDefinition,
RNADefinition,
StructureDefinition,
)
from boltz_data.sequence import sequence_from_residue_names
from ._branched import get_branched_polymers_from_mmcif
from ._nonpolymer import get_nonpolymers_from_mmcif
from ._polymer import get_polymers_from_mmcif
from ._utils import clean_string
# Lookup from an mmCIF polymer type string (the value read from `polymer.type`
# in `get_structure_from_mmcif`) to the Boltz entity kind. Both L- and
# D-polypeptides map to "protein", and DNA/RNA hybrids are mapped to "dna".
MMCIF_POLYMER_TYPE_TO_BOLTZ_TYPE = dict(
    (
        ("polypeptide(L)", "protein"),
        ("polypeptide(D)", "protein"),
        ("polydeoxyribonucleotide", "dna"),
        ("polyribonucleotide", "rna"),
        ("polydeoxyribonucleotide/polyribonucleotide hybrid", "dna"),
    )
)
[docs]
def get_structure_from_mmcif(mmcif: gemmi.cif.Block, /) -> StructureDefinition:
"""
Parse an mmCIF block into a StructureDefinition.
Extracts entities (proteins, DNA, RNA, branched polymers, ligands), chains,
and bond information from the mmCIF format.
Args:
mmcif: The mmCIF block to parse.
Returns:
A StructureDefinition containing all entities, chains, and bonds.
Raises:
ValueError: If no chains are found in the structure or if an unknown polymer type is encountered.
"""
polymers = get_polymers_from_mmcif(mmcif)
branched_polymers = get_branched_polymers_from_mmcif(mmcif)
nonpolymers = get_nonpolymers_from_mmcif(mmcif)
chain_to_internal_bonds = _get_polymer_bonds(mmcif)
chain_to_entity_id = {
chain_id: int(entity_id) for chain_id, entity_id in mmcif.find("_struct_asym.", ["id", "entity_id"])
}
entity_id_to_example_chain = {entity_id: chain_id for chain_id, entity_id in chain_to_entity_id.items()}
entities: dict[int, EntityDefinition] = {}
chain_to_residue_numbers: dict[str, list[int]] = {}
for entity_id, entity_type, description in mmcif.find("_entity.", ["id", "type", "pdbx_description"]):
if entity_type == "water":
continue
description_formatted = clean_string(description) if description != "?" else None
if int(entity_id) in polymers:
polymer = polymers[int(entity_id)]
boltz_type = MMCIF_POLYMER_TYPE_TO_BOLTZ_TYPE[polymer.type]
chain_to_residue_numbers.update(polymer.chain_to_residue_numbers)
internal_bonds = chain_to_internal_bonds[entity_id_to_example_chain[int(entity_id)]]
match boltz_type:
case "protein":
entities[int(entity_id)] = ProteinDefinition(
type=boltz_type,
sequence=sequence_from_residue_names(
polymer.comp_ids, polymer_type=boltz_type, nonstandard_handling="parentheses"
),
description=description_formatted,
bonds=internal_bonds,
)
case "rna":
entities[int(entity_id)] = RNADefinition(
type=boltz_type,
sequence=sequence_from_residue_names(
polymer.comp_ids, polymer_type=boltz_type, nonstandard_handling="parentheses"
),
description=description_formatted,
bonds=internal_bonds,
)
case "dna":
entities[int(entity_id)] = DNADefinition(
type=boltz_type,
sequence=sequence_from_residue_names(
polymer.comp_ids, polymer_type=boltz_type, nonstandard_handling="parentheses"
),
description=description_formatted,
bonds=internal_bonds,
)
case _:
msg = f"Unknown polymer type: {boltz_type}"
raise ValueError(msg)
elif int(entity_id) in branched_polymers:
branched_entity = branched_polymers[int(entity_id)]
chain_to_residue_numbers.update(branched_entity.chain_to_residue_numbers)
entities[int(entity_id)] = BranchedPolymerDefinition(
type="branched_polymer",
comp_ids=branched_entity.comp_ids,
bonds=branched_entity.bonds,
description=description_formatted,
)
else:
nonpolymer = nonpolymers[int(entity_id)]
chain_to_residue_numbers.update(nonpolymer.chain_to_residue_numbers)
entities[int(entity_id)] = LigandCCDDefinition(
type="ligand_ccd",
comp_id=nonpolymer.comp_id,
description=description_formatted,
)
chains = {
chain_id: ChainDefinition(
entity_idx=list(entities).index(entity_id), residue_numbers=chain_to_residue_numbers.get(chain_id)
)
for chain_id, entity_id in chain_to_entity_id.items()
if entity_id in entities
}
if len(chains) == 0:
msg = "No chains in structure"
raise ValueError(msg)
bonds = _get_bonds(mmcif, chains)
return StructureDefinition(entities=list(entities.values()), chains=chains, bonds=bonds)
def _get_polymer_bonds(mmcif: gemmi.cif.Block) -> Mapping[str, list[InternalBond]]:
    """
    Collect intra-chain covalent and disulfide bonds from `_struct_conn`.

    A connection is kept only when its type is "covale" or "disulf", both
    partners have the same `label_asym_id`, and both partners carry a
    `label_seq_id`.

    Args:
        mmcif: The mmCIF block to read `_struct_conn` from.

    Returns:
        A mapping from chain ID (`label_asym_id`) to the internal bonds of that
        chain. Residue indices are zero-based (`label_seq_id - 1`). The mapping
        is a defaultdict, so indexing a chain without bonds yields an empty list.
    """
    null_values = {".", "?"}
    chain_to_internal_bonds: defaultdict[str, list[InternalBond]] = defaultdict(list)

    for conn_type, chain_id_1, seqnum_1, atom_name_1, chain_id_2, seqnum_2, atom_name_2 in mmcif.find(
        "_struct_conn.",
        [
            # The connection type lives in `conn_type_id`. Requesting a tag that
            # does not exist makes gemmi return an empty table, which would
            # silently drop every bond.
            "conn_type_id",
            "ptnr1_label_asym_id",
            "ptnr1_label_seq_id",
            "ptnr1_label_atom_id",
            "ptnr2_label_asym_id",
            "ptnr2_label_seq_id",
            "ptnr2_label_atom_id",
        ],
    ):
        if conn_type not in {"covale", "disulf"}:
            continue
        if chain_id_1 != chain_id_2:
            continue
        # A partner without a sequence id (CIF null "." or "?") cannot be
        # converted to a residue index, so the connection is skipped.
        if seqnum_1 in null_values or seqnum_2 in null_values:
            continue

        chain_to_internal_bonds[chain_id_1].append(
            InternalBond(
                residue_index_1=int(seqnum_1) - 1,
                atom_name_1=clean_string(atom_name_1),
                residue_index_2=int(seqnum_2) - 1,
                atom_name_2=clean_string(atom_name_2),
            )
        )

    return chain_to_internal_bonds
def _get_bonds(mmcif: gemmi.cif.Block, chains: dict[str, ChainDefinition]) -> list[InterChainBond] | None:
    """
    Collect covalent and disulfide bonds from `_struct_conn` as InterChainBond records.

    A connection is kept only when its type is "disulf" or "covale", both
    partners have the same symmetry value, and both partner chains are present
    in `chains`. Each residue index is the position of the partner's
    `auth_seq_id` within the `residue_numbers` of its chain.

    Args:
        mmcif: The mmCIF block to read `_struct_conn` from.
        chains: Chain definitions keyed by chain ID (`label_asym_id`).

    Returns:
        The list of bonds, or None when no bond was collected.

    Raises:
        ValueError: If a partner's residue number is not listed in its chain's
            `residue_numbers` (including when the chain has none).
    """
    # NOTE(review): connections whose two partners are in the same chain are not
    # filtered out here; confirm whether that is intended for InterChainBond.
    columns = [
        "conn_type_id",
        "ptnr1_label_asym_id",  # chain ID
        "ptnr1_auth_seq_id",
        "ptnr1_label_atom_id",
        "ptnr2_label_asym_id",
        "ptnr2_auth_seq_id",
        "ptnr2_label_atom_id",
        "ptnr1_symmetry",
        # Tag names must match exactly: a stray trailing space makes the tag
        # unknown, and gemmi then returns an empty table for the whole query.
        "ptnr2_symmetry",
    ]

    bonds: list[InterChainBond] = []
    for (
        conn_type,
        chain_id_1,
        resnum_1,
        atom_id_1,
        chain_id_2,
        resnum_2,
        atom_id_2,
        symmetry_1,
        symmetry_2,
    ) in mmcif.find("_struct_conn.", columns):
        if conn_type not in {"disulf", "covale"}:
            continue
        # Only keep connections whose partners share the same symmetry value.
        if symmetry_1 != symmetry_2:
            continue
        # A bond to a chain that is not part of the structure (e.g. one the
        # caller dropped) cannot be represented, so it is skipped.
        if chain_id_1 not in chains or chain_id_2 not in chains:
            continue

        residue_index_1 = (chains[chain_id_1].residue_numbers or []).index(int(resnum_1))  # type: ignore[index]
        residue_index_2 = (chains[chain_id_2].residue_numbers or []).index(int(resnum_2))  # type: ignore[index]

        bonds.append(
            InterChainBond(
                chain_id_1=chain_id_1,
                residue_index_1=residue_index_1,
                atom_name_1=clean_string(atom_id_1),
                chain_id_2=chain_id_2,
                residue_index_2=residue_index_2,
                atom_name_2=clean_string(atom_id_2),
                bond_order=1,
            )
        )

    return bonds if bonds else None