# Source code for boltz_data.mol._validate

import numpy as np

from ._clash import get_atom_pairs_within_distance
from ._mol import BZBioMol, BZMol

# Plausible range for a bond length, in Angstroms. These match the 0.5 / 3
# bounds that validate_bzmol applies to bonds between resolved atoms.
MIN_BOND_LENGTH: float = 0.5  # Angstroms
MAX_BOND_LENGTH: float = 3.0  # Angstroms

# Distance threshold handed to get_atom_pairs_within_distance when looking
# for clashing atom pairs.
MIN_CLASH_DISTANCE: float = 1  # Angstroms


class UnrealisticBondLengthError(ValueError):
    """Raised when a bond's length falls outside the accepted range.

    Subclasses ValueError, so callers that already catch ValueError keep
    working.
    """


def validate_bzmol(mol: BZBioMol, /) -> None:
    """Validate a BZBioMol instance for consistency.

    Nothing is checked unless both ``mol.atom_coordinates`` and
    ``mol.atom_resolved`` are set. When they are, two checks run in order:

    1. Every bond whose two atoms are both resolved must have a length in
       ``[MIN_BOND_LENGTH, MAX_BOND_LENGTH]`` (Angstroms).
    2. No pair of resolved atoms may be reported by
       ``get_atom_pairs_within_distance`` at ``MIN_CLASH_DISTANCE``.

    Args:
        mol: The molecule to validate.

    Raises:
        UnrealisticBondLengthError: If a bond between two resolved atoms has
            a length outside the accepted range (NaN lengths included).
        ValueError: On the first clash between two resolved atoms, or if
            ``mol.bond_length`` does not have exactly one entry per bond
            (raised by ``zip(..., strict=True)``).
    """
    if mol.atom_coordinates is None or mol.atom_resolved is None:
        return
    _validate_bond_lengths(mol)
    _validate_no_clashes(mol)


def _validate_bond_lengths(mol: BZBioMol, /) -> None:
    """Raise UnrealisticBondLengthError for the first out-of-range bond."""
    # A missing bond_length is treated as empty; strict=True then makes zip
    # raise ValueError if the molecule nevertheless has bonds.
    bond_lengths = mol.bond_length if mol.bond_length is not None else []
    for bond_atoms, bond_length in zip(mol.bond_atoms, bond_lengths, strict=True):
        atom_a, atom_b = bond_atoms[0], bond_atoms[1]
        # Bonds touching an unresolved atom are not checked.
        if not mol.atom_resolved[atom_a] or not mol.atom_resolved[atom_b]:
            continue
        # Negated range test (rather than `< MIN` / `> MAX`) so that NaN
        # lengths, which fail every comparison, are rejected as well.
        if not (MIN_BOND_LENGTH <= bond_length <= MAX_BOND_LENGTH):
            msg = (
                f"Unrealistic bond length: {bond_length} for bond "
                f"between {_format_atom(mol, atom_a)} and {_format_atom(mol, atom_b)}"
            )
            raise UnrealisticBondLengthError(msg)


def _validate_no_clashes(mol: BZBioMol, /) -> None:
    """Raise ValueError for the first clashing pair of resolved atoms."""
    clashes = get_atom_pairs_within_distance(bzmol=mol, threshold=MIN_CLASH_DISTANCE)
    for i, j in clashes:
        # Skip pairs involving an unresolved atom, consistent with the bond
        # check above. NOTE(review): if get_atom_pairs_within_distance already
        # excludes unresolved atoms, this filter is a harmless no-op.
        if not mol.atom_resolved[i] or not mol.atom_resolved[j]:
            continue
        msg = f"Clash detected between {_format_atom(mol, i)} and {_format_atom(mol, j)} of {_distance(mol, i, j)}"
        raise ValueError(msg)
def _format_residue(mol: BZBioMol, idx: int, /) -> str:
    """Describe residue ``idx`` as ``<name><number> of chain <chain id>``.

    The residue number is left out when ``mol.residue_number`` is None, and
    the chain id is shown as ``?`` when ``mol.chain_id`` is None.
    """
    name = mol.residue_name[idx]
    if mol.residue_number is None:
        number = ""
    else:
        number = mol.residue_number[idx]
    if mol.chain_id is None:
        chain = "?"
    else:
        chain = mol.chain_id[mol.residue_chain[idx]]
    return f"{name}{number} of chain {chain}"


def _format_atom(mol: BZMol | BZBioMol, idx: int, /) -> str:
    """Describe atom ``idx`` for use in error messages.

    For a BZBioMol the result is ``<atom name> of <residue description>``;
    any other molecule falls back to ``atom <idx>``.
    """
    if not isinstance(mol, BZBioMol):
        return f"atom {idx}"
    owning_residue = int(mol.atom_residue[idx])
    return f"{mol.atom_name[idx]} of {_format_residue(mol, owning_residue)}"


def _distance(mol: BZBioMol, i: int, j: int, /) -> float:
    """Return the Euclidean distance between atoms ``i`` and ``j``.

    Raises:
        ValueError: If ``mol.atom_coordinates`` is None.
    """
    coords = mol.atom_coordinates
    if coords is None:
        msg = "Cannot compute distance without atom coordinates."
        raise ValueError(msg)
    delta = coords[i] - coords[j]
    return float(np.linalg.norm(delta))