import math
import numpy as np
from rdkit.Chem import GetPeriodicTable
from boltz_data.mol import BZMol
from ._cpk import CPK, get_color
from ._primitive import DrawnLineSegment, DrawnSphere, Primitive, Text
# Shared RDKit periodic table, used to look up element symbols for atoms
# that have no dedicated CPK colour entry.
PERIODIC_TABLE = GetPeriodicTable()

# ---------------------------------------------------------------------------
# Visualization constants
# ---------------------------------------------------------------------------

# Atom spheres and the element-symbol label drawn on top of them.
ATOM_SPHERE_RADIUS = 0.15
ELEMENT_LABEL_FONT_SIZE = 0.15
ELEMENT_LABEL_COLOR = "#ffffff"

# Charge labels: displacement added to the atom centre, and font sizes for a
# bare sign (|charge| == 1) versus a sign followed by a magnitude.
CHARGE_LABEL_OFFSET = np.array([0.17, -0.17, 0.0])
CHARGE_LABEL_FONT_SIZE_SINGLE = 0.25
CHARGE_LABEL_FONT_SIZE_MULTI = 0.15

# Bonds: single bonds are one thick segment; double bonds are two thinner
# segments displaced by +/- DOUBLE_BOND_OFFSET across the bond.
SINGLE_BOND_WIDTH = 0.2
DOUBLE_BOND_WIDTH = 0.11
DOUBLE_BOND_OFFSET = 0.09

# Aromatic ring circles: inset from the ring's apothem, and outline width.
AROMATIC_RING_OFFSET = 0.25
AROMATIC_LINE_WIDTH = 0.1
def ball_and_stick(bzmol: BZMol) -> list[Primitive]:
    """
    Generate ball-and-stick representation primitives for a molecule.

    Creates spheres for atoms and line segments for bonds, colored using CPK coloring.
    Only includes resolved atoms and bonds between resolved atoms. When every resolved
    atom lies in the z = 0 plane (a flat 2D depiction), each aromatic ring whose atoms
    are all resolved is additionally marked with a circle outline inside the ring.

    Args:
        bzmol: The molecule to visualize.

    Returns:
        List of primitives (spheres, line segments, text) representing the molecule,
        ordered as bond primitives, then atom primitives, then aromatic ring outlines.
    """
    bond_primitives = [
        primitive
        for bond_idx, bond in enumerate(bzmol.bond_atoms)
        if bzmol.atom_resolved[bond[0]] and bzmol.atom_resolved[bond[1]]
        for primitive in _create_bond_primitive(bzmol, bond_idx, bond[0], bond[1])
    ]
    atom_primitives = [
        primitive
        for atom in range(bzmol.num_atoms)
        if bzmol.atom_resolved[atom]
        for primitive in _create_atom_primitives(bzmol, atom)
    ]
    return bond_primitives + atom_primitives + _create_aromatic_ring_primitives(bzmol)


def _create_aromatic_ring_primitives(bzmol: BZMol, /) -> list[Primitive]:
    """
    Outline each fully resolved aromatic ring with a circle, for flat depictions only.

    The circle is centred on the mean of the ring atom coordinates. Its radius is the
    ring's apothem (largest centre-to-atom distance times cos(pi / n)) minus
    AROMATIC_RING_OFFSET, so it sits inside the ring bonds.

    Rings are skipped when:
    - they contain an unresolved atom;
    - they have fewer than three atoms;
    - the computed radius is not positive.

    Args:
        bzmol: The molecule whose aromatic rings are drawn.

    Returns:
        One outline DrawnSphere per drawable ring, or an empty list if any resolved
        atom has a non-zero z coordinate.
    """
    resolved_mask = np.asarray(bzmol.atom_resolved, dtype=bool)
    # Only resolved atoms decide whether this is a flat 2D depiction; the coordinates
    # of unresolved atoms are not drawn anywhere else either.
    if not (bzmol.atom_coordinates[resolved_mask][:, 2] == 0).all():
        return []

    primitives: list[Primitive] = []
    for ring in bzmol.aromatic_rings:
        # Convert to an integer index array so a tuple-typed ring is not interpreted
        # by numpy as a multi-axis index.
        ring_atoms = np.asarray(ring, dtype=int)
        # Fewer than 3 atoms is not a ring (and n == 0 would divide by zero below);
        # a ring with an unresolved atom would be centred on meaningless coordinates.
        if len(ring_atoms) < 3 or not resolved_mask[ring_atoms].all():
            continue
        coords = bzmol.atom_coordinates[ring_atoms]
        center = coords.mean(axis=0)
        # For a regular n-gon, apothem = circumradius * cos(pi / n).
        circumradius = np.linalg.norm(coords - center, axis=1).max()
        radius = circumradius * math.cos(math.pi / len(ring_atoms)) - AROMATIC_RING_OFFSET
        if radius <= 0:
            # Ring too small for the inset: there is no sensible circle to draw.
            continue
        primitives.append(
            # CPK is keyed by atomic number, so CPK[6] is the carbon colour.
            DrawnSphere(center=center, radius=radius, outline_color=CPK[6], outline_width=AROMATIC_LINE_WIDTH)
        )
    return primitives
def _create_bond_primitive(bzmol: BZMol, bond: int, atom1: int, atom2: int, /) -> list[Primitive]:
    """
    Build the line segment(s) that draw a single bond.

    A non-aromatic bond of order 2 becomes two thinner parallel segments. They are
    displaced by +/- DOUBLE_BOND_OFFSET along the direction perpendicular to the bond
    in the screen (x/y) plane. Every other bond becomes one segment of width
    SINGLE_BOND_WIDTH. Segments are coloured from atom1's CPK colour to atom2's.

    Args:
        bzmol: The molecule containing the bond.
        bond: Index of the bond in the molecule's bond arrays.
        atom1: Index of the first bonded atom.
        atom2: Index of the second bonded atom.

    Returns:
        A list with one (single-line) or two (double bond) DrawnLineSegment primitives.
    """
    coord1 = bzmol.atom_coordinates[atom1]
    coord2 = bzmol.atom_coordinates[atom2]
    start_color = get_color(int(bzmol.atom_element[atom1]))
    end_color = get_color(int(bzmol.atom_element[atom2]))

    if bzmol.bond_is_aromatic[bond] or bzmol.bond_order[bond] != 2:
        return [
            DrawnLineSegment(points=[coord1, coord2], start_color=start_color, end_color=end_color, width=SINGLE_BOND_WIDTH)
        ]

    # Unit vector perpendicular to the bond's projection onto the screen plane.
    dx = coord2[0] - coord1[0]
    dy = coord2[1] - coord1[1]
    screen_length = math.hypot(dx, dy)
    if screen_length > 0:
        transverse = np.array([-dy, dx, 0.0]) / screen_length
    else:
        # The bond projects to a single screen point (it points along the view axis),
        # so every screen direction is perpendicular; pick x to avoid dividing by zero.
        transverse = np.array([1.0, 0.0, 0.0])
    shift = transverse * DOUBLE_BOND_OFFSET

    return [
        DrawnLineSegment(
            points=[coord1 + offset, coord2 + offset],
            start_color=start_color,
            end_color=end_color,
            width=DOUBLE_BOND_WIDTH,
        )
        for offset in (shift, -shift)
    ]
def _create_atom_primitives(bzmol: BZMol, atom: int, /) -> list[Primitive]:
    """
    Build the primitives that depict one atom.

    The atom is always drawn as a sphere in its CPK colour. Two labels may be added:
    - Elements with no entry in ``CPK`` get a text label with their element symbol,
      drawn at the atom centre in ELEMENT_LABEL_COLOR.
    - Atoms with a non-zero charge get a charge label in the atom's colour, placed
      at CHARGE_LABEL_OFFSET from the centre. A bare sign (|charge| == 1) uses a
      larger font than a sign followed by a magnitude.

    Args:
        bzmol: The molecule containing the atom.
        atom: Index of the atom to draw.

    Returns:
        The sphere, optionally followed by the element label and then the charge label.
    """
    center = bzmol.atom_coordinates[atom]
    atomic_number = int(bzmol.atom_element[atom])
    atom_charge = int(bzmol.atom_charge[atom])
    atom_color = get_color(atomic_number)

    result: list[Primitive] = []
    result.append(DrawnSphere(center=center, radius=ATOM_SPHERE_RADIUS, color=atom_color))

    if atomic_number not in CPK:
        # No dedicated colour for this element, so spell out its symbol instead.
        symbol = PERIODIC_TABLE.GetElementSymbol(atomic_number)
        result.append(Text(center=center, text=symbol, font_size=ELEMENT_LABEL_FONT_SIZE, color=ELEMENT_LABEL_COLOR))

    if atom_charge == 0:
        return result

    if abs(atom_charge) == 1:
        label_size = CHARGE_LABEL_FONT_SIZE_SINGLE
    else:
        label_size = CHARGE_LABEL_FONT_SIZE_MULTI
    result.append(
        Text(
            center=center + CHARGE_LABEL_OFFSET,
            text=_format_charge_text(atom_charge),
            font_size=label_size,
            color=atom_color,
        )
    )
    return result
def _format_charge_text(charge: int, /) -> str:
abs_charge = abs(charge)
sign = "+" if charge > 0 else "-"
magnitude = str(abs_charge) if abs_charge > 1 else ""
return f"{sign}{magnitude}"