# Source code for boltz_data.mol._geom

from collections import defaultdict

import numpy as np
import numpy.typing as npt
from pydantic import BaseModel
from scipy.spatial import KDTree

from boltz_data.geom import BoundingBoxes, Spheres
from boltz_data.pydantic import NDArray

from ._mol import BZBioMol


def get_residue_bounding_boxes(bzmol: BZBioMol, /) -> BoundingBoxes:
    """
    Compute an axis-aligned bounding box for every residue from its resolved atoms.

    Only atoms flagged in ``bzmol.atom_resolved`` contribute. The result has one box per
    residue (``bzmol.num_residues`` rows); a residue without any resolved atom keeps the
    sentinel box ``min=+inf`` / ``max=-inf``.
    """
    resolved = bzmol.atom_resolved
    coords = bzmol.atom_coordinates[resolved]
    residue_of_atom = bzmol.atom_residue[resolved]

    # Start from the identity elements of min/max so the first atom always wins.
    box_shape = (bzmol.num_residues, 3)
    lower = np.full(box_shape, np.inf)
    upper = np.full(box_shape, -np.inf)

    # ufunc.at reduces in place without buffering, so every atom that maps to the
    # same residue row takes part in that row's min/max.
    np.minimum.at(lower, residue_of_atom, coords)
    np.maximum.at(upper, residue_of_atom, coords)

    return BoundingBoxes(min=lower, max=upper)


def get_overlapping_residues(bzmol: BZBioMol, /, threshold: float) -> npt.NDArray[np.int32]:
    """
    Find residue pairs from different chains whose grown bounding boxes overlap.

    Every residue bounding box is grown by ``threshold / 2`` (via ``BoundingBoxes.grow``)
    and the grown boxes are tested for overlap with a sweep along the x-axis followed by
    a y/z interval check. Residues of the same chain are never paired.

    Args:
        bzmol: Molecule providing per-residue bounding boxes and ``residue_chain``.
        threshold: Distance used to grow the boxes before the overlap test.

    Returns:
        Integer array of shape ``(num_pairs, 2)``. Each row is ``[i, j]`` with ``i < j``,
        both being residue indices. The array has shape ``(0, 2)`` when nothing overlaps.
    """
    bounding_boxes = get_residue_bounding_boxes(bzmol)
    bounding_boxes = bounding_boxes.grow(threshold / 2.0)

    mins, maxs = bounding_boxes.min, bounding_boxes.max
    chains = bzmol.residue_chain

    # Sort by x-axis minimum so the inner scan can stop at the first box that
    # starts beyond the current box's x-extent.
    order = np.argsort(mins[:, 0])

    pairs = []
    for idx, i in enumerate(order):
        # Check against all boxes that start before this one ends
        for j in order[idx + 1 :]:
            # If j starts after i ends in x, no more overlaps possible
            if mins[j, 0] > maxs[i, 0]:
                break

            if chains[i] == chains[j]:
                continue

            # Check y and z overlap
            if (
                maxs[i, 1] >= mins[j, 1]
                and mins[i, 1] <= maxs[j, 1]
                and maxs[i, 2] >= mins[j, 2]
                and mins[i, 2] <= maxs[j, 2]
            ):
                pairs.append([i, j] if i < j else [j, i])

    if not pairs:
        # np.array([]) would be a float64 array of shape (0,); keep the documented
        # (n, 2) integer layout so callers can unpack rows unconditionally.
        return np.empty((0, 2), dtype=np.int32)

    return np.array(pairs, dtype=np.int32)


def get_residue_bounding_spheres_around_centroid(bzmol: BZBioMol, /) -> Spheres:
    """
    Calculate bounding spheres for each residue centered at the residue centroid.

    The centroid is the mean of the residue's resolved atom coordinates and the radius is
    the largest distance from that centroid to any of those atoms.

    Residues without resolved atoms are skipped, so the returned spheres may be fewer
    than ``bzmol.num_residues``.
    NOTE(review): because of the skipping, sphere ``k`` is not guaranteed to belong to
    residue ``k`` - confirm callers do not rely on index alignment.
    """
    centers = []
    radii = []

    for residue in range(bzmol.num_residues):
        atom_mask = (bzmol.atom_residue == residue) & bzmol.atom_resolved
        coords = bzmol.atom_coordinates[atom_mask]

        if len(coords) == 0:
            continue

        center = coords.mean(axis=0)
        centers.append(center)
        radii.append(np.linalg.norm(coords - center, axis=-1).max())

    if not centers:
        # np.array([]) has shape (0,); keep the (n, 3) center layout for the empty case.
        return Spheres(center=np.empty((0, 3)), radius=np.empty(0))

    return Spheres(center=np.array(centers), radius=np.array(radii))


class Interface(BaseModel):
    """Contact interface between two chains of a molecule, described by paired residues."""

    model_config = {"arbitrary_types_allowed": True}

    # Molecule the chain and residue indices below refer to.
    bzmol: BZBioMol
    # Chain identifiers of the two sides of the interface.
    chain1: int
    chain2: int
    # Contact count supplied by whoever builds the interface.
    num_atoms_within_5a: int
    # Residue indices on the chain1 side and the chain2 side.
    residues1: NDArray[np.int32]
    residues2: NDArray[np.int32]

    @property
    def relative_residues1(self) -> npt.NDArray[np.int32] | npt.NDArray[np.int64]:
        """
        Positions of ``residues1`` within the residues belonging to ``chain1``.

        Positions are returned in ascending order; a residue listed several times in
        ``residues1`` yields one entry per occurrence.
        """
        chain_members = np.nonzero(self.bzmol.residue_chain == self.chain1)[0]
        # Rows index the chain's residues, columns index residues1.
        hits = chain_members[:, np.newaxis] == self.residues1[np.newaxis, :]
        return np.nonzero(hits)[0]

    @property
    def relative_residues2(self) -> npt.NDArray[np.int32] | npt.NDArray[np.int64]:
        """
        Positions of ``residues2`` within the residues belonging to ``chain2``.

        Positions are returned in ascending order; a residue listed several times in
        ``residues2`` yields one entry per occurrence.
        """
        chain_members = np.nonzero(self.bzmol.residue_chain == self.chain2)[0]
        # Rows index the chain's residues, columns index residues2.
        hits = chain_members[:, np.newaxis] == self.residues2[np.newaxis, :]
        return np.nonzero(hits)[0]


def get_molecular_interfaces(bzmol: BZBioMol, /, threshold: float = 5.0) -> list[Interface]:
    """
    Find interfaces between chains using atoms within threshold distance.

    Uses a two-pass approach:
    1. Find residue pairs with overlapping bounding boxes (fast sweep-and-prune)
    2. Check atom distances only for those residue pairs (KDTree per pair)

    This leverages residue grouping to avoid O(n²) atom comparisons.

    Args:
        bzmol: Molecule to analyse.
        threshold: Maximum atom-atom distance that counts as a contact.

    Returns:
        One ``Interface`` per chain pair with at least one contact, ordered with the
        smaller chain id as ``chain1``. ``residues1[k]`` / ``residues2[k]`` form a
        contacting residue pair. ``num_atoms_within_5a`` is the number of atom pairs
        within ``threshold`` for that chain pair only (the field name mentions 5 A, but
        the value follows ``threshold``).
    """
    # Get overlapping residue pairs (different chains only)
    overlapping_residues = get_overlapping_residues(bzmol, threshold=threshold)

    if len(overlapping_residues) == 0:
        return []

    # Contacting residue pairs and atom-pair counts, both keyed by chain pair
    contacts: dict[tuple[int, int], list[tuple[int, int]]] = defaultdict(list)
    contact_counts: dict[tuple[int, int], int] = defaultdict(int)

    # A residue usually appears in several candidate pairs; compute its resolved
    # coordinates once instead of re-masking every atom for each pair.
    coords_cache: dict = {}

    def resolved_coordinates(residue):
        if residue not in coords_cache:
            atom_mask = (bzmol.atom_residue == residue) & bzmol.atom_resolved
            coords_cache[residue] = bzmol.atom_coordinates[atom_mask]
        return coords_cache[residue]

    # For each overlapping residue pair, check atom distances
    for res_i, res_j in overlapping_residues:
        coords_i = resolved_coordinates(res_i)
        coords_j = resolved_coordinates(res_j)

        if len(coords_i) == 0 or len(coords_j) == 0:
            continue

        # Build KDTree for one residue, query with other
        tree = KDTree(coords_i)
        neighbors_per_atom = tree.query_ball_point(coords_j, threshold)

        # Count total contacts (atom pairs within threshold)
        num_contacts = sum(len(neighbors) for neighbors in neighbors_per_atom)
        if num_contacts == 0:
            continue

        # Plain ints keep the keys and the Interface fields consistent with their annotations.
        chain_i = int(bzmol.residue_chain[res_i])
        chain_j = int(bzmol.residue_chain[res_j])

        if chain_i < chain_j:
            chain_pair, residue_pair = (chain_i, chain_j), (res_i, res_j)
        else:
            chain_pair, residue_pair = (chain_j, chain_i), (res_j, res_i)

        contacts[chain_pair].append(residue_pair)
        # Accumulate per chain pair; a single global total would be reported
        # identically on every interface.
        contact_counts[chain_pair] += num_contacts

    # Convert to Interface objects
    return [
        Interface(
            bzmol=bzmol,
            chain1=c1,
            chain2=c2,
            num_atoms_within_5a=contact_counts[(c1, c2)],
            residues1=np.array([res[0] for res in residue_pairs]),
            residues2=np.array([res[1] for res in residue_pairs]),
        )
        for (c1, c2), residue_pairs in contacts.items()
    ]