# Source code for boltz_data.draw._rdkit

"""Drawing code for rendering molecules using RDKit."""

from collections.abc import Callable
from typing import TypedDict

from rdkit import Chem
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D


# Keyword arguments forwarded to ``MolDraw2D.DrawMolecule``. Every key is optional
# (``total=False``), so an empty mapping means "draw without highlighting".
_DrawArgs = TypedDict(
    "_DrawArgs",
    {
        "highlightAtoms": list[int],  # indices of the atoms to highlight
        "highlightAtomColors": dict[int, tuple[float, float, float]],  # atom index -> RGB colour
    },
    total=False,
)


def _residue_number_alternating_scheme(atom: Chem.Atom) -> tuple[float, float, float] | None:
    """Choose a highlight colour from the parity of the atom's PDB residue number.

    Returns ``None`` (leave the atom unhighlighted) when the atom has no PDB residue
    info, ``(0.95, 0.9, 0.85)`` for even residue numbers and ``(0.8, 0.85, 0.9)`` for
    odd ones, so neighbouring residues alternate between two colours.
    """
    residue_info = atom.GetPDBResidueInfo()
    if residue_info is None:
        return None
    # Python's % always yields 0 or 1 here (also for negative numbers), so a
    # non-zero remainder means "odd".
    if residue_info.GetResidueNumber() % 2:
        return (0.8, 0.85, 0.9)
    return (0.95, 0.9, 0.85)


# Named atom-colouring schemes accepted as a string ``atom_color`` argument; each maps an
# atom to an RGB tuple, or to ``None`` to leave that atom unhighlighted.
_ATOM_SCHEMES: dict[str, Callable[[Chem.Atom], tuple[float, float, float] | None]] = dict(
    residue_alternating=_residue_number_alternating_scheme,
)


def _generate_highlight(
    *, rdmol: Chem.Mol, atom_color: str | Callable[[Chem.Atom], tuple[float, float, float] | None]
) -> _DrawArgs:
    """Build the atom-highlight keyword arguments for ``MolDraw2D.DrawMolecule``.

    Args:
        rdmol: Molecule whose atoms are coloured.
        atom_color: Either the name of a scheme registered in ``_ATOM_SCHEMES`` or a
            callable mapping an atom to an RGB tuple (or ``None`` to skip that atom).

    Returns:
        A mapping with ``highlightAtoms`` (indices of coloured atoms, in atom order) and
        ``highlightAtomColors`` (atom index -> RGB), covering exactly the atoms for which
        the colour function returned a non-``None`` value.

    Raises:
        KeyError: If ``atom_color`` is a string naming no registered scheme.
    """
    if isinstance(atom_color, str):
        try:
            color_fn = _ATOM_SCHEMES[atom_color]
        except KeyError:
            # Keep KeyError for callers that already catch it, but say what is valid.
            known = ", ".join(sorted(_ATOM_SCHEMES))
            raise KeyError(f"Unknown atom color scheme {atom_color!r}; expected one of: {known}") from None
    else:
        color_fn = atom_color

    highlight_atom_colors = {
        atom.GetIdx(): color
        for atom in rdmol.GetAtoms()  # type: ignore[no-untyped-call]
        if (color := color_fn(atom)) is not None
    }
    # Dicts preserve insertion order, so the key list is the highlighted indices in atom order.
    return {"highlightAtoms": list(highlight_atom_colors), "highlightAtomColors": highlight_atom_colors}


def draw_rdmol(
    rdmol: Chem.Mol,
    /,
    caption: str | None = None,
    atom_color: str | Callable[[Chem.Atom], tuple[float, float, float] | None] | None = None,
) -> str:
    """
    Draw an RDKit molecule to an SVG with stable coordinate generation.

    Generates 2D coordinates canonically and renders the molecule as SVG.
    Optionally highlights atoms with custom colors based on a scheme or callable.
    Atoms that carry PDB residue info are labelled with their PDB atom name.

    The input molecule is left untouched: drawing happens on a copy, and the
    process-wide RDKit CoordGen preference is restored before returning.

    Args:
        rdmol: The RDKit molecule to draw.
        caption: Optional caption to display below the molecule.
        atom_color: Optional atom coloring scheme. Can be:
            - "residue_alternating": Colors atoms by even/odd residue number
            - A callable taking an Atom and returning RGB tuple or None

    Returns:
        SVG string representation of the molecule.

    Raises:
        KeyError: If ``atom_color`` is a string naming an unknown scheme.
    """
    # Work on a copy: Compute2DCoords clears existing conformers by default (e.g. 3D
    # coordinates read from a structure file), and "atomNote" props are set below.
    # Neither side effect should leak into the caller's molecule.
    rdmol = Chem.Mol(rdmol)

    # Ensure stable 2D coordinate generation across platforms by using CoordGen.
    # The preference is process-global RDKit state, so restore the caller's setting.
    prefer_coordgen = rdDepictor.GetPreferCoordGen()
    rdDepictor.SetPreferCoordGen(val=True)
    try:
        rdDepictor.Compute2DCoords(rdmol, canonOrient=True)
    finally:
        rdDepictor.SetPreferCoordGen(val=prefer_coordgen)

    # Label each atom with its PDB atom name when residue info is available.
    for atom in rdmol.GetAtoms():  # type: ignore[no-untyped-call]
        if (residue_info := atom.GetPDBResidueInfo()) is not None:
            atom.SetProp("atomNote", residue_info.GetName())

    # (-1, -1) asks RDKit for a flexible canvas sized to the drawing.
    drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)
    kwargs: _DrawArgs = {}
    if atom_color:
        kwargs |= _generate_highlight(rdmol=rdmol, atom_color=atom_color)
    drawer.DrawMolecule(rdmol, legend=caption or "", **kwargs)
    drawer.FinishDrawing()
    return drawer.GetDrawingText()