Source code for boltz_data.mol._op._assembly

import itertools
from collections.abc import Generator

import gemmi
import numpy as np

from boltz_data.mol._mol import BZBioMol
from boltz_data.transformation import AffineTransformation, IdentityTransformation

from ._chain_id import chain_id_generator
from ._concat import concat_bzmols
from ._subset import subset_bzmol
from ._transform import transform_bzmol


def iterate_assemblies(
    *, mmcif: gemmi.cif.Block, bzmol: BZBioMol, max_atoms: int | None = None
) -> Generator[BZBioMol]:
    """Generate biological assemblies from an mmCIF block and its asymmetric-unit BZBioMol.

    An assembly listed in ``_pdbx_struct_assembly`` may be described by several
    ``_pdbx_struct_assembly_gen`` rows, each applying its own operator expression to
    its own ``asym_id_list``. All rows that belong to one assembly are combined, and
    exactly one concatenated BZBioMol is yielded per assembly.

    Copies produced by a pure identity operation keep their original chain IDs.
    Every other copy receives fresh chain IDs from ``chain_id_generator()`` that do
    not clash with any chain ID referenced by the assembly.

    Args:
        mmcif: Parsed mmCIF data block providing the ``_pdbx_struct_assembly``,
            ``_pdbx_struct_oper_list`` and ``_pdbx_struct_assembly_gen`` categories.
        bzmol: Asymmetric-unit molecule. It is subset with the chain IDs listed in
            each ``asym_id_list``.
        max_atoms: Optional upper bound on the number of atoms in a generated
            assembly. Assemblies that would exceed it are skipped. ``None`` disables
            the limit.

    Yields:
        One BZBioMol per assembly, in order of first appearance in
        ``_pdbx_struct_assembly_gen``.

    Raises:
        KeyError: If an operator expression references an operation ID that is not
            defined in ``_pdbx_struct_oper_list``.
    """
    valid_assembly_ids = {
        gemmi.cif.as_string(assembly_id) for (assembly_id,) in mmcif.find("_pdbx_struct_assembly.", ["id"])
    }
    oper_id_to_symm_operation = _read_oper_list(mmcif)

    # One assembly may span several generator rows. Group them so that each assembly
    # is emitted once and complete, instead of once per row.
    generators_by_assembly: dict[str, list[tuple[list[str], list[list[str]]]]] = {}
    for raw_assembly_id, raw_oper_expression, raw_chain_ids in mmcif.find(
        "_pdbx_struct_assembly_gen.", ["assembly_id", "oper_expression", "asym_id_list"]
    ):
        assembly_id = gemmi.cif.as_string(raw_assembly_id)
        if assembly_id not in valid_assembly_ids:
            continue
        generators_by_assembly.setdefault(assembly_id, []).append(
            (
                _split_asym_id_list(raw_chain_ids),
                parse_oper_expression(gemmi.cif.as_string(raw_oper_expression)),
            )
        )

    for generators in generators_by_assembly.values():
        subsets = [
            (subset_bzmol(bzmol, chain_ids=chain_ids), operations) for chain_ids, operations in generators
        ]
        num_atoms = sum(subset.num_atoms * len(operations) for subset, operations in subsets)
        if max_atoms is not None and num_atoms > max_atoms:
            continue

        # Fresh chain IDs must not clash with any chain kept unchanged by an identity copy,
        # whichever generator row that copy comes from.
        used_chain_ids = {chain_id for chain_ids, _ in generators for chain_id in chain_ids}
        available_chain_id = (chain_id for chain_id in chain_id_generator() if chain_id not in used_chain_ids)

        asymmetric_unit_copies: list[BZBioMol] = []
        for assembly_bzmol, operations in subsets:
            for oper_ids in operations:
                transformation = _compose_oper_ids(oper_ids, oper_id_to_symm_operation)
                if isinstance(transformation, IdentityTransformation):
                    # Identity copies are reused as-is and keep their original chain IDs.
                    asymmetric_unit_copies.append(assembly_bzmol)
                    continue
                asymmetric_unit_copies.append(
                    transform_bzmol(
                        assembly_bzmol,
                        # Transform coordinates
                        atom_coordinates=transformation.transform_points(assembly_bzmol.atom_coordinates)
                        if assembly_bzmol.atom_coordinates is not None
                        else None,
                        # Assign fresh chain IDs to this copy
                        chain_id=np.array(
                            list(itertools.islice(available_chain_id, len(assembly_bzmol.chain_id)))
                        ),
                    )
                )
        if asymmetric_unit_copies:
            yield concat_bzmols(*asymmetric_unit_copies)


def _read_oper_list(mmcif: gemmi.cif.Block) -> dict[str, IdentityTransformation | AffineTransformation]:
    """Map each ``_pdbx_struct_oper_list.id`` (unquoted) to its transformation."""
    columns = [
        "id",
        "type",
        "matrix[1][1]",
        "matrix[1][2]",
        "matrix[1][3]",
        "vector[1]",
        "matrix[2][1]",
        "matrix[2][2]",
        "matrix[2][3]",
        "vector[2]",
        "matrix[3][1]",
        "matrix[3][2]",
        "matrix[3][3]",
        "vector[3]",
    ]
    oper_id_to_symm_operation: dict[str, IdentityTransformation | AffineTransformation] = {}
    for op_id, op_type, m1, m2, m3, v1, m4, m5, m6, v2, m7, m8, m9, v3 in mmcif.find(
        "_pdbx_struct_oper_list.", columns
    ):
        # Compare the unquoted value so both 'identity operation' and "identity operation" match.
        if gemmi.cif.as_string(op_type) == "identity operation":
            oper_id_to_symm_operation[gemmi.cif.as_string(op_id)] = IdentityTransformation()
        else:
            oper_id_to_symm_operation[gemmi.cif.as_string(op_id)] = AffineTransformation(
                matrix=np.array([[m1, m2, m3], [m4, m5, m6], [m7, m8, m9]], dtype=np.float32),
                translation=np.array([v1, v2, v3], dtype=np.float32),
            )
    return oper_id_to_symm_operation


def _split_asym_id_list(raw_chain_ids: str) -> list[str]:
    """Split a raw ``asym_id_list`` value into chain IDs, ignoring quoting, text-field markers and whitespace."""
    return [chain_id.strip() for chain_id in gemmi.cif.as_string(raw_chain_ids).split(",") if chain_id.strip()]


def _compose_oper_ids(
    oper_ids: list[str],
    oper_id_to_symm_operation: dict[str, IdentityTransformation | AffineTransformation],
) -> IdentityTransformation | AffineTransformation:
    """Chain ``@`` over the operations named by ``oper_ids`` (in order), starting from the identity."""
    transformation: IdentityTransformation | AffineTransformation = IdentityTransformation()
    for oper_id in oper_ids:
        transformation = transformation @ oper_id_to_symm_operation[oper_id]  # type: ignore[assignment]
    return transformation
def parse_oper_expression(oper_expression: str, /) -> list[list[str]]:
    """Expand a ``_pdbx_struct_assembly_gen.oper_expression`` into operation-ID sequences.

    Supported forms may be wrapped in single or double quotes and may contain
    arbitrary whitespace or line breaks:

    * ``1`` - a single operation;
    * ``(1,2,5)`` - a list of operations;
    * ``(1-4)`` - an inclusive numeric range;
    * ``(1,2)(3-4)`` - the Cartesian product of parenthesised groups.

    Empty tokens (for example from a trailing comma) are ignored.

    Args:
        oper_expression: Operator expression, either raw (quoted) or already unquoted.

    Returns:
        One list of operation IDs per combination, enumerated like
        ``itertools.product`` over the groups. Each inner list holds one ID per
        group, in the order the groups appear in the expression.

    Raises:
        ValueError: If a range token is not of the form ``<int>-<int>``.

    >>> parse_oper_expression("(1,2)(3-4)")
    [['1', '3'], ['1', '4'], ['2', '3'], ['2', '4']]
    """
    # Long expressions stored as CIF text fields may be wrapped across lines.
    expression = "".join(oper_expression.split())
    if len(expression) >= 2 and expression[0] == expression[-1] and expression[0] in "'\"":
        expression = expression[1:-1]
    if expression.startswith("(") and expression.endswith(")"):
        expression = expression[1:-1]

    groups: list[list[str]] = []
    for group in expression.split(")("):
        operations: list[str] = []
        for token in group.split(","):
            if not token:
                continue
            if "-" in token:
                start, end = map(int, token.split("-"))
                operations.extend(str(i) for i in range(start, end + 1))
            else:
                operations.append(token)
        groups.append(operations)
    return [list(ops) for ops in itertools.product(*groups)]