# Source code for boltz_data.draw._rdkit
"""Drawing code for rendering molecules using RDKit."""
from collections.abc import Callable
from typing import TypedDict
from rdkit import Chem
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
class _DrawArgs(TypedDict, total=False):
highlightAtoms: list[int]
highlightAtomColors: dict[int, tuple[float, float, float]]
def _residue_number_alternating_scheme(atom: Chem.Atom) -> tuple[float, float, float] | None:
"""Color atoms based on whether their residue number is even or odd."""
if atom.GetPDBResidueInfo() is None:
return None
return (0.95, 0.9, 0.85) if atom.GetPDBResidueInfo().GetResidueNumber() % 2 == 0 else (0.8, 0.85, 0.9)
# Named atom-coloring schemes accepted as a string ``atom_color`` argument.
# Each entry maps an atom to an RGB tuple, or to ``None`` to skip highlighting it.
_ATOM_SCHEMES: dict[str, Callable[[Chem.Atom], tuple[float, float, float] | None]] = dict(
    residue_alternating=_residue_number_alternating_scheme,
)
def _generate_highlight(
*, rdmol: Chem.Mol, atom_color: str | Callable[[Chem.Atom], tuple[float, float, float] | None]
) -> _DrawArgs:
highlight_atoms = []
highlight_atom_colors = {}
if isinstance(atom_color, str):
atom_color = _ATOM_SCHEMES[atom_color]
for atom in rdmol.GetAtoms(): # type: ignore[no-untyped-call]
color = atom_color(atom)
if color is None:
continue
highlight_atoms.append(atom.GetIdx())
highlight_atom_colors[atom.GetIdx()] = color
return {"highlightAtoms": highlight_atoms, "highlightAtomColors": highlight_atom_colors}
# [docs]
def draw_rdmol(
    rdmol: Chem.Mol,
    /,
    caption: str | None = None,
    atom_color: str | Callable[[Chem.Atom], tuple[float, float, float] | None] | None = None,
) -> str:
    """
    Draw an RDKit molecule to an SVG with stable coordinate generation.

    Generates 2D coordinates canonically and renders the molecule as SVG.
    Atoms that carry PDB residue info are annotated with their PDB atom name.
    Optionally highlights atoms with custom colors based on a scheme or callable.

    The input molecule is not modified. Drawing happens on a copy, so the caller's
    conformers (e.g. 3D coordinates) and atom properties are left untouched.
    RDKit's global CoordGen preference is restored after coordinate generation.

    Args:
        rdmol: The RDKit molecule to draw.
        caption: Optional caption to display below the molecule.
        atom_color: Optional atom coloring scheme. Can be:
            - "residue_alternating": Colors atoms by even/odd residue number
            - A callable taking an Atom and returning RGB tuple or None

    Returns:
        SVG string representation of the molecule.

    Raises:
        KeyError: If ``atom_color`` is a string that names no known scheme.
    """
    # Work on a copy. Compute2DCoords discards existing conformers by default
    # (clearConfs=True), and the loop below sets "atomNote" props. Neither change
    # should leak back to the caller's molecule.
    rdmol = Chem.Mol(rdmol)

    # Ensure stable 2D coordinate generation.
    # Use rdDepictor for consistent results across platforms. The CoordGen preference
    # is process-global RDKit state, so put the caller's setting back afterwards.
    prefer_coordgen = rdDepictor.GetPreferCoordGen()
    rdDepictor.SetPreferCoordGen(val=True)
    try:
        rdDepictor.Compute2DCoords(rdmol, canonOrient=True)
    finally:
        rdDepictor.SetPreferCoordGen(val=prefer_coordgen)

    # Label atoms that have PDB residue info with their PDB atom name.
    for atom in rdmol.GetAtoms():  # type: ignore[no-untyped-call]
        if (residue_info := atom.GetPDBResidueInfo()) is not None:
            atom.SetProp("atomNote", residue_info.GetName())

    # Width/height of -1 asks RDKit to size the canvas from the drawing itself.
    d = rdMolDraw2D.MolDraw2DSVG(-1, -1)
    kwargs: _DrawArgs = {}
    if atom_color:
        kwargs |= _generate_highlight(rdmol=rdmol, atom_color=atom_color)
    d.DrawMolecule(rdmol, legend=caption or "", **kwargs)
    d.FinishDrawing()
    return d.GetDrawingText()