from collections import defaultdict
import numpy as np
import numpy.typing as npt
from pydantic import BaseModel
from scipy.spatial import KDTree
from boltz_data.geom import BoundingBoxes, Spheres
from boltz_data.pydantic import NDArray
from ._mol import BZBioMol
def get_residue_bounding_boxes(bzmol: BZBioMol, /) -> BoundingBoxes:
    """Compute an axis-aligned bounding box for every residue.

    Only atoms flagged in ``bzmol.atom_resolved`` contribute. A residue that
    has no resolved atoms keeps the sentinel extents it is initialised with
    (``min = +inf`` and ``max = -inf`` on every axis).

    Args:
        bzmol: Molecule providing ``atom_coordinates``, ``atom_resolved``,
            ``atom_residue`` and ``num_residues``.

    Returns:
        A ``BoundingBoxes`` whose ``min`` and ``max`` arrays both have shape
        ``(bzmol.num_residues, 3)``.
    """
    resolved = bzmol.atom_resolved
    coords = bzmol.atom_coordinates[resolved]
    residue_of_atom = bzmol.atom_residue[resolved]

    # Start from the identity elements of min/max so that the first atom seen
    # for a residue always replaces the placeholder.
    extent_shape = (bzmol.num_residues, 3)
    lower = np.full(extent_shape, np.inf)
    upper = np.full(extent_shape, -np.inf)

    # ``ufunc.at`` applies the operation unbuffered, so several atoms that
    # share a residue index all fold into that residue's row.
    np.minimum.at(lower, residue_of_atom, coords)
    np.maximum.at(upper, residue_of_atom, coords)

    return BoundingBoxes(min=lower, max=upper)
def get_overlapping_residues(bzmol: BZBioMol, /, threshold: float) -> npt.NDArray[np.int32]:
    """Find pairs of residues on different chains whose bounding boxes overlap.

    Each residue's bounding box is passed through ``grow(threshold / 2.0)``
    (presumably expanding it by that margin on every side -- confirm against
    ``BoundingBoxes.grow``), then a sweep along the x axis collects every
    pair of boxes that intersect on all three axes.

    Args:
        bzmol: Molecule providing ``residue_chain`` plus whatever
            ``get_residue_bounding_boxes`` reads.
        threshold: Distance used to derive the growth margin for the boxes.

    Returns:
        An ``int32`` array of shape ``(num_pairs, 2)``. Each row is
        ``[i, j]`` with ``i < j``; both residues belong to different chains.
        When nothing overlaps the result has shape ``(0, 2)``.
    """
    boxes = get_residue_bounding_boxes(bzmol).grow(threshold / 2.0)
    mins, maxs = boxes.min, boxes.max
    chains = bzmol.residue_chain

    # Visit boxes in order of increasing x-minimum so the inner scan can stop
    # as soon as a candidate starts beyond the current box's x-maximum.
    order = np.argsort(mins[:, 0])

    pairs: list[tuple[int, int]] = []
    for idx, i in enumerate(order):
        for j in order[idx + 1 :]:
            # Every later box starts even further along x, so none of them
            # can overlap box ``i`` either.
            if mins[j, 0] > maxs[i, 0]:
                break
            # Only inter-chain pairs are of interest.
            if chains[i] == chains[j]:
                continue
            # x already overlaps; confirm the y and z intervals intersect too.
            if (
                maxs[i, 1] >= mins[j, 1]
                and mins[i, 1] <= maxs[j, 1]
                and maxs[i, 2] >= mins[j, 2]
                and mins[i, 2] <= maxs[j, 2]
            ):
                pairs.append((i, j) if i < j else (j, i))

    # An explicit dtype and reshape keep the result a 2-column integer array
    # even when ``pairs`` is empty (a bare ``np.array([])`` would be a
    # float64 array of shape ``(0,)``).
    return np.asarray(pairs, dtype=np.int32).reshape(-1, 2)
[docs]
def get_residue_bounding_spheres_around_centroid(bzmol: BZBioMol, /) -> Spheres:
    """Calculate bounding spheres for each residue centered at the residue centroid.

    For every residue that has at least one resolved atom, the sphere centre
    is the mean of its resolved atom coordinates and the radius is the
    largest distance from that centre to any of those atoms.

    NOTE: residues without resolved atoms are skipped, so the output holds
    one sphere per *resolved* residue in ascending residue order. Row ``k``
    of the result is therefore not necessarily residue ``k``.

    Args:
        bzmol: Molecule providing ``num_residues``, ``atom_residue``,
            ``atom_resolved`` and ``atom_coordinates``.

    Returns:
        A ``Spheres`` with ``center`` of shape ``(n, 3)`` and ``radius`` of
        shape ``(n,)``, where ``n`` is the number of residues that have
        resolved atoms (possibly zero).
    """
    # Hoist the attribute lookups out of the per-residue loop.
    atom_residue = bzmol.atom_residue
    atom_resolved = bzmol.atom_resolved
    atom_coordinates = bzmol.atom_coordinates

    centers = []
    radii = []
    for residue in range(bzmol.num_residues):
        coords = atom_coordinates[(atom_residue == residue) & atom_resolved]
        if len(coords) == 0:
            continue
        center = coords.mean(axis=0)
        centers.append(center)
        radii.append(np.linalg.norm(coords - center, axis=-1).max())

    # The reshape keeps ``center`` two-dimensional when no residue has any
    # resolved atoms (``np.array([])`` alone would have shape ``(0,)``).
    return Spheres(center=np.array(centers).reshape(-1, 3), radius=np.array(radii))
class Interface(BaseModel):
    """Contact interface between two chains of a molecule.

    ``residues1`` and ``residues2`` are index arrays into the molecule's
    residues; the ``relative_residues*`` properties express them as
    positions within the corresponding chain.
    """

    # Molecule the interface was computed on.
    bzmol: BZBioMol
    # Identifiers of the two chains, compared against ``bzmol.residue_chain``.
    chain1: int
    chain2: int
    # Contact count supplied by whoever builds the interface.
    num_atoms_within_5a: int
    # Residue indices on the ``chain1`` side and on the ``chain2`` side.
    residues1: NDArray[np.int32]
    residues2: NDArray[np.int32]

    @property
    def relative_residues1(self) -> npt.NDArray[np.int32] | npt.NDArray[np.int64]:
        """Positions of the ``residues1`` entries within the residues of ``chain1``.

        One position is emitted per matching (chain residue, ``residues1``
        entry) combination, in ascending order of position, so a residue
        listed several times in ``residues1`` appears several times here.
        """
        members = np.nonzero(self.bzmol.residue_chain == self.chain1)[0]
        # Rows index residues of the chain, columns index entries of residues1.
        hits = members[:, None] == self.residues1[None, :]
        return np.nonzero(hits)[0]

    @property
    def relative_residues2(self) -> npt.NDArray[np.int32] | npt.NDArray[np.int64]:
        """Positions of the ``residues2`` entries within the residues of ``chain2``.

        One position is emitted per matching (chain residue, ``residues2``
        entry) combination, in ascending order of position, so a residue
        listed several times in ``residues2`` appears several times here.
        """
        members = np.nonzero(self.bzmol.residue_chain == self.chain2)[0]
        # Rows index residues of the chain, columns index entries of residues2.
        hits = members[:, None] == self.residues2[None, :]
        return np.nonzero(hits)[0]

    # ``BZBioMol`` is not a type pydantic knows how to validate natively.
    model_config = {"arbitrary_types_allowed": True}
[docs]
def get_molecular_interfaces(bzmol: BZBioMol, /, threshold: float = 5.0) -> list[Interface]:
    """
    Find interfaces between chains using atoms within threshold distance.

    Uses a two-pass approach:

    1. Find residue pairs with overlapping bounding boxes (fast sweep-and-prune)
    2. Check atom distances only for those residue pairs (KDTree per pair)

    This leverages residue grouping to avoid O(n²) atom comparisons.

    Args:
        bzmol: Molecule providing ``residue_chain``, ``atom_residue``,
            ``atom_resolved`` and ``atom_coordinates``.
        threshold: Maximum distance between two resolved atoms for them to
            count as a contact.

    Returns:
        One ``Interface`` per pair of chains that has at least one contact,
        with ``chain1 < chain2``. ``residues1[k]`` / ``residues2[k]`` are the
        two residues of the k-th contacting residue pair, so a residue can
        appear more than once. ``num_atoms_within_5a`` is the number of
        (atom, atom) pairs within ``threshold`` for that chain pair only;
        despite the field name, the cutoff is ``threshold``, not a fixed 5.
        An empty list is returned when no residues overlap.
    """
    # Get overlapping residue pairs (different chains only)
    overlapping_residues = get_overlapping_residues(bzmol, threshold=threshold)
    if len(overlapping_residues) == 0:
        return []

    # Contacting residue pairs and atom-pair counts, both keyed by the
    # ordered chain pair so each interface reports its own total.
    contacts: dict[tuple[int, int], list[tuple[int, int]]] = defaultdict(list)
    contact_counts: dict[tuple[int, int], int] = defaultdict(int)

    # Hoist the attribute lookups out of the per-pair loop.
    atom_residue = bzmol.atom_residue
    atom_resolved = bzmol.atom_resolved
    atom_coordinates = bzmol.atom_coordinates
    residue_chain = bzmol.residue_chain

    # For each overlapping residue pair, check atom distances
    for res_i, res_j in overlapping_residues:
        coords_i = atom_coordinates[(atom_residue == res_i) & atom_resolved]
        coords_j = atom_coordinates[(atom_residue == res_j) & atom_resolved]
        if len(coords_i) == 0 or len(coords_j) == 0:
            continue

        # Build KDTree for one residue, query with the other; the result
        # holds, for each atom of res_j, the atoms of res_i within threshold.
        neighbors = KDTree(coords_i).query_ball_point(coords_j, threshold)
        num_contacts = sum(len(hits) for hits in neighbors)
        if num_contacts == 0:
            continue

        # Normalise the orientation so the smaller chain id always comes
        # first, keeping the residue pair in the matching order.
        chain_i = residue_chain[res_i]
        chain_j = residue_chain[res_j]
        if chain_i < chain_j:
            chain_pair = (chain_i, chain_j)
            residue_pair = (res_i, res_j)
        else:
            chain_pair = (chain_j, chain_i)
            residue_pair = (res_j, res_i)

        contacts[chain_pair].append(residue_pair)
        contact_counts[chain_pair] += num_contacts

    # Convert to Interface objects
    return [
        Interface(
            bzmol=bzmol,
            chain1=c1,
            chain2=c2,
            num_atoms_within_5a=contact_counts[(c1, c2)],
            residues1=np.array([pair[0] for pair in residue_pairs]),
            residues2=np.array([pair[1] for pair in residue_pairs]),
        )
        for (c1, c2), residue_pairs in contacts.items()
    ]