import itertools
from collections.abc import Generator
import gemmi
import numpy as np
from boltz_data.mol._mol import BZBioMol
from boltz_data.transformation import AffineTransformation, IdentityTransformation
from ._chain_id import chain_id_generator
from ._concat import concat_bzmols
from ._subset import subset_bzmol
from ._transform import transform_bzmol
def iterate_assemblies(*, mmcif: gemmi.cif.Block, bzmol: BZBioMol, max_atoms: int | None = None) -> Generator[BZBioMol]:
    """Generate biological assemblies from an mmCIF file and asymmetric unit BZBioMol.

    One molecule is yielded per ``_pdbx_struct_assembly_gen`` row whose ``assembly_id``
    also appears in ``_pdbx_struct_assembly.id``. For each row, the chains listed in
    ``asym_id_list`` are selected from ``bzmol`` and one copy is produced per operator
    combination of ``oper_expression``. The copies are then concatenated.

    Copies whose composed transformation is an ``IdentityTransformation`` keep their
    original chain IDs. Every other copy has its coordinates transformed and receives
    fresh chain IDs from ``chain_id_generator()``, skipping the IDs used by this row.

    NOTE(review): an assembly described by several gen rows is yielded as several
    separate molecules (one per row), not merged. Confirm this is what callers expect.

    Args:
        mmcif: Parsed mmCIF data block providing the ``_pdbx_struct_assembly``,
            ``_pdbx_struct_oper_list`` and ``_pdbx_struct_assembly_gen`` categories.
        bzmol: The asymmetric unit to expand.
        max_atoms: If not ``None``, a row is skipped when its expanded atom count
            (selected atoms x number of operator combinations) exceeds this value.

    Yields:
        The concatenated copies for each accepted gen row.

    Raises:
        KeyError: If an operator id in ``oper_expression`` is missing from
            ``_pdbx_struct_oper_list``.
    """
    valid_assembly_ids = {assembly_id for (assembly_id,) in mmcif.find("_pdbx_struct_assembly.", ["id"])}
    oper_id_to_symm_operation = _read_symmetry_operations(mmcif)

    for assembly_id, oper_expression, asym_id_list in mmcif.find(
        "_pdbx_struct_assembly_gen.", ["assembly_id", "oper_expression", "asym_id_list"]
    ):
        if assembly_id not in valid_assembly_ids:
            continue

        used_chain_id = _parse_asym_id_list(asym_id_list)
        assembly_bzmol = subset_bzmol(bzmol, chain_ids=used_chain_id)
        # as_string removes CIF quoting / text-field delimiters before parsing.
        operations = parse_oper_expression(gemmi.cif.as_string(oper_expression))

        # `is not None` so that max_atoms=0 is honoured as a real limit.
        if max_atoms is not None and assembly_bzmol.num_atoms * len(operations) > max_atoms:
            continue

        # Fresh chain IDs for transformed copies must not clash with the originals.
        reserved_chain_ids = set(used_chain_id)
        available_chain_id = (chain_id for chain_id in chain_id_generator() if chain_id not in reserved_chain_ids)

        asymmetric_unit_copies: list[BZBioMol] = []
        for oper_ids in operations:
            transformation = IdentityTransformation()
            for oper_id in oper_ids:
                transformation = transformation @ oper_id_to_symm_operation[oper_id]  # type: ignore[assignment]

            if isinstance(transformation, IdentityTransformation):
                # The identity copy is the selected chains unchanged, original chain IDs kept.
                asymmetric_unit_copies.append(assembly_bzmol)
                continue

            asymmetric_unit_copies.append(
                transform_bzmol(
                    assembly_bzmol,
                    atom_coordinates=(
                        transformation.transform_points(assembly_bzmol.atom_coordinates)
                        if assembly_bzmol.atom_coordinates is not None
                        else None
                    ),
                    chain_id=np.array(list(itertools.islice(available_chain_id, len(assembly_bzmol.chain_id)))),
                )
            )

        yield concat_bzmols(*asymmetric_unit_copies)


def _read_symmetry_operations(mmcif: gemmi.cif.Block) -> dict[str, IdentityTransformation | AffineTransformation]:
    """Map each unquoted ``_pdbx_struct_oper_list.id`` to its transformation.

    Rows whose ``type`` is ``identity operation`` become ``IdentityTransformation``.
    All other rows become an ``AffineTransformation`` built from the 3x3 matrix and
    translation vector columns, as float32.
    """
    columns = [
        "id",
        "type",
        "matrix[1][1]",
        "matrix[1][2]",
        "matrix[1][3]",
        "vector[1]",
        "matrix[2][1]",
        "matrix[2][2]",
        "matrix[2][3]",
        "vector[2]",
        "matrix[3][1]",
        "matrix[3][2]",
        "matrix[3][3]",
        "vector[3]",
    ]
    oper_id_to_symm_operation: dict[str, IdentityTransformation | AffineTransformation] = {}
    for op_id, op_type, m1, m2, m3, v1, m4, m5, m6, v2, m7, m8, m9, v3 in mmcif.find(
        "_pdbx_struct_oper_list.", columns
    ):
        # Raw CIF values keep their quotes; compare and key on the unquoted content.
        key = gemmi.cif.as_string(op_id)
        if gemmi.cif.as_string(op_type) == "identity operation":
            oper_id_to_symm_operation[key] = IdentityTransformation()
        else:
            oper_id_to_symm_operation[key] = AffineTransformation(
                matrix=np.array([[m1, m2, m3], [m4, m5, m6], [m7, m8, m9]], dtype=np.float32),
                translation=np.array([v1, v2, v3], dtype=np.float32),
            )
    return oper_id_to_symm_operation


def _parse_asym_id_list(raw_value: str) -> list[str]:
    """Split a raw ``asym_id_list`` CIF value into chain IDs.

    Long lists are often stored as ``;...`` text fields that may be wrapped across
    lines. ``gemmi.cif.as_string`` removes the delimiters, and each comma-separated
    entry is then stripped of whitespace. Empty entries are dropped.
    """
    entries = (entry.strip() for entry in gemmi.cif.as_string(raw_value).split(","))
    return [entry for entry in entries if entry]
def parse_oper_expression(oper_expression: str, /) -> list[list[str]]:
    """Expand an mmCIF ``oper_expression`` into combinations of operator ids.

    Accepted forms, optionally wrapped in single or double quotes:

    - ``1``
    - ``1,2,5``
    - ``1-4`` (an inclusive integer range)
    - ``(1-60)``
    - products of parenthesised groups such as ``(1,2)(3,4)``

    The result is the Cartesian product of the groups, in group order. Whitespace
    anywhere in the expression is ignored, so line-wrapped values parse correctly.
    Empty comma entries (e.g. a trailing comma) are skipped.

    Args:
        oper_expression: The expression text.

    Returns:
        One list of operator ids per combination, e.g. ``"(1,2)(3)"`` gives
        ``[["1", "3"], ["2", "3"]]``.

    Raises:
        ValueError: If a group expands to nothing (empty expression, empty
            parentheses, or a reversed range such as ``5-1``), or if a range
            token is not of the form ``<int>-<int>``.
    """
    expression = oper_expression.strip()
    if len(expression) >= 2 and expression[0] == expression[-1] and expression[0] in "'\"":
        expression = expression[1:-1]
    # Operator ids never contain whitespace; removing it handles wrapped text fields.
    expression = "".join(expression.split())
    if expression.startswith("(") and expression.endswith(")"):
        expression = expression[1:-1]

    groups: list[list[str]] = []
    for group_text in expression.split(")("):
        operations = _expand_oper_group(group_text)
        if not operations:
            raise ValueError(f"Empty operator group in oper_expression {oper_expression!r}")
        groups.append(operations)
    return [list(combination) for combination in itertools.product(*groups)]


def _expand_oper_group(group_text: str) -> list[str]:
    """Expand one comma-separated group, turning ``a-b`` into ``a, a+1, ..., b``."""
    operations: list[str] = []
    for token in group_text.split(","):
        if not token:
            continue
        if "-" in token:
            start, end = map(int, token.split("-"))
            operations.extend(str(i) for i in range(start, end + 1))
        else:
            operations.append(token)
    return operations