# Source code for boltz_data.draw.mol3d._ball_and_stick

import math

import numpy as np
from rdkit.Chem import GetPeriodicTable

from boltz_data.mol import BZMol

from ._cpk import CPK, get_color
from ._primitive import DrawnLineSegment, DrawnSphere, Primitive, Text

# Shared RDKit periodic table; used below to look up the element symbol shown
# on atoms that have no CPK color entry.
PERIODIC_TABLE = GetPeriodicTable()

# Visualization constants
# NOTE(review): all sizes/offsets below are combined directly with atom
# coordinates, so they are presumably in the same length unit as
# `BZMol.atom_coordinates` -- confirm against the renderer.

# Radius of the sphere drawn for every resolved atom.
ATOM_SPHERE_RADIUS = 0.15
# Font size and color of the element-symbol label drawn on atoms whose element
# is not in the CPK color table.
ELEMENT_LABEL_FONT_SIZE = 0.15
ELEMENT_LABEL_COLOR = "#ffffff"
# Displacement from the atom center to the center of its formal-charge label.
CHARGE_LABEL_OFFSET = np.array([0.17, -0.17, 0])
# Charge-label font size: "SINGLE" is used when |charge| == 1 (label is just
# the sign), "MULTI" when |charge| > 1 (label is sign plus magnitude).
CHARGE_LABEL_FONT_SIZE_SINGLE = 0.25
CHARGE_LABEL_FONT_SIZE_MULTI = 0.15

# Line width of a bond drawn as one segment (every bond that is not a
# non-aromatic double bond).
SINGLE_BOND_WIDTH = 0.2
# Line width of each of the two parallel segments of a non-aromatic double bond.
DOUBLE_BOND_WIDTH = 0.11
# Sideways shift of each double-bond segment away from the bond axis.
DOUBLE_BOND_OFFSET = 0.09

# Amount subtracted from a ring's inscribed radius to get the radius of the
# aromatic circle, keeping the circle clear of the ring bonds.
AROMATIC_RING_OFFSET = 0.25
# Outline width of the aromatic circle.
AROMATIC_LINE_WIDTH = 0.1

def ball_and_stick(bzmol: BZMol) -> list[Primitive]:
    """
    Generate ball-and-stick representation primitives for a molecule.

    Creates spheres for atoms and line segments for bonds, colored using CPK
    coloring. Only includes resolved atoms and bonds between resolved atoms.
    When every atom has z == 0 (a flat 2D depiction), a circle is also drawn
    inside each aromatic ring whose atoms are all resolved.

    Args:
        bzmol: The molecule to visualize.

    Returns:
        List of primitives (spheres, line segments, text) representing the
        molecule, ordered as bonds, then atoms, then aromatic-ring circles.
    """
    bond_primitives = [
        primitive
        for bond_idx, bond in enumerate(bzmol.bond_atoms)
        if bzmol.atom_resolved[bond[0]] and bzmol.atom_resolved[bond[1]]
        for primitive in _create_bond_primitive(bzmol, bond_idx, bond[0], bond[1])
    ]

    atom_primitives = [
        primitive
        for atom in range(bzmol.num_atoms)
        if bzmol.atom_resolved[atom]
        for primitive in _create_atom_primitives(bzmol, atom)
    ]

    return bond_primitives + atom_primitives + _create_aromatic_ring_primitives(bzmol)


def _create_aromatic_ring_primitives(bzmol: BZMol, /) -> list[Primitive]:
    """
    Create the circle drawn inside each aromatic ring of a flat molecule.

    Args:
        bzmol: The molecule to visualize.

    Returns:
        One outlined circle per drawable aromatic ring; an empty list when the
        molecule is not flat (some atom has z != 0).
    """
    # Aromatic circles are only drawn for flat depictions lying in the z == 0 plane.
    if not (bzmol.atom_coordinates[:, 2] == 0).all():
        return []

    ring_primitives: list[Primitive] = []
    for ring in bzmol.aromatic_rings:
        # Atoms and bonds are only drawn when resolved; a circle computed from
        # unresolved atoms would be placed using meaningless coordinates.
        if not all(bzmol.atom_resolved[atom] for atom in ring):
            continue

        coords = bzmol.atom_coordinates[ring]
        center = coords.mean(axis=0)
        circumradius = np.linalg.norm(coords - center, axis=1).max()
        # cos(pi / n) converts the circumradius of a regular n-gon into its
        # inscribed radius; the offset then pulls the circle off the bonds.
        radius = circumradius * math.cos(math.pi / len(ring)) - AROMATIC_RING_OFFSET

        # Skip degenerate rings (non-positive or NaN radius) rather than
        # emitting an invalid circle.
        if not radius > 0:
            continue

        ring_primitives.append(
            DrawnSphere(
                center=center,
                radius=radius,
                # CPK[6] is the CPK table entry for element 6 (carbon).
                outline_color=CPK[6],
                outline_width=AROMATIC_LINE_WIDTH,
            )
        )
    return ring_primitives


def _create_bond_primitive(bzmol: BZMol, bond: int, atom1: int, atom2: int, /) -> list[Primitive]:
    """
    Create the line segment(s) depicting a single bond.

    A non-aromatic bond of order 2 is drawn as two thin parallel segments
    shifted to either side of the bond axis; every other bond is drawn as one
    segment. Segments are colored from the first atom's color to the second's.

    Args:
        bzmol: The molecule the bond belongs to.
        bond: Index of the bond (into ``bond_is_aromatic`` / ``bond_order``).
        atom1: Index of the bond's first atom.
        atom2: Index of the bond's second atom.

    Returns:
        A list with one segment, or two for a non-aromatic double bond.
    """
    coord1 = bzmol.atom_coordinates[atom1]
    coord2 = bzmol.atom_coordinates[atom2]
    start_color = get_color(int(bzmol.atom_element[atom1]))
    end_color = get_color(int(bzmol.atom_element[atom2]))

    if bzmol.bond_is_aromatic[bond] or bzmol.bond_order[bond] != 2:
        return [
            DrawnLineSegment(
                points=[coord1, coord2],
                start_color=start_color,
                end_color=end_color,
                width=SINGLE_BOND_WIDTH,
            )
        ]

    # The sideways direction is only needed (and only computed) for double bonds.
    shift = _bond_screen_transverse(coord1, coord2) * DOUBLE_BOND_OFFSET
    return [
        DrawnLineSegment(
            points=[coord1 + shift, coord2 + shift],
            start_color=start_color,
            end_color=end_color,
            width=DOUBLE_BOND_WIDTH,
        ),
        DrawnLineSegment(
            points=[coord1 - shift, coord2 - shift],
            start_color=start_color,
            end_color=end_color,
            width=DOUBLE_BOND_WIDTH,
        ),
    ]


def _bond_screen_transverse(coord1: np.ndarray, coord2: np.ndarray, /) -> np.ndarray:
    """
    Return a unit vector in the x/y plane perpendicular to a bond's x/y projection.

    Args:
        coord1: 3D coordinates of the bond's first atom.
        coord2: 3D coordinates of the bond's second atom.

    Returns:
        A 3-vector with zero z component and unit length. When the projection
        has no usable length (the bond points straight along z, or the
        coordinates are not finite) the x axis is returned instead of dividing
        by zero and producing NaN.
    """
    bond_dir_screen = coord2[:2] - coord1[:2]
    transverse = np.array([-bond_dir_screen[1], bond_dir_screen[0], 0])
    length = np.linalg.norm(transverse)
    if not length > 0:
        # Any in-plane direction is perpendicular to a bond seen end-on.
        return np.array([1.0, 0.0, 0.0])
    return transverse / length


def _create_atom_primitives(bzmol: BZMol, atom: int, /) -> list[Primitive]:
    """
    Create the primitives depicting a single atom.

    Args:
        bzmol: The molecule the atom belongs to.
        atom: Index of the atom.

    Returns:
        The atom's sphere, followed by an element-symbol label when the element
        has no CPK color entry, followed by a charge label when the formal
        charge is non-zero.
    """
    coord = bzmol.atom_coordinates[atom]
    element = int(bzmol.atom_element[atom])
    charge = int(bzmol.atom_charge[atom])
    color = get_color(element)

    primitives: list[Primitive] = [DrawnSphere(center=coord, radius=ATOM_SPHERE_RADIUS, color=color)]

    # Elements without a dedicated CPK color cannot be told apart by color
    # alone, so they are labeled with their symbol.
    if element not in CPK:
        primitives.append(
            Text(
                center=coord,
                text=PERIODIC_TABLE.GetElementSymbol(element),
                font_size=ELEMENT_LABEL_FONT_SIZE,
                color=ELEMENT_LABEL_COLOR,
            )
        )

    if charge != 0:
        # A bare "+" / "-" is rendered larger than a sign-plus-magnitude label.
        font_size = CHARGE_LABEL_FONT_SIZE_SINGLE if abs(charge) == 1 else CHARGE_LABEL_FONT_SIZE_MULTI
        primitives.append(
            Text(
                center=coord + CHARGE_LABEL_OFFSET,
                text=_format_charge_text(charge),
                font_size=font_size,
                color=color,
            )
        )

    return primitives


def _format_charge_text(charge: int, /) -> str:
    """
    Format a formal charge as a short label.

    Args:
        charge: The formal charge.

    Returns:
        The sign followed by the magnitude, with the magnitude omitted when it
        is 1 or less (e.g. ``"+"``, ``"-"``, ``"+2"``, ``"-3"``).
    """
    abs_charge = abs(charge)
    sign = "+" if charge > 0 else "-"
    magnitude = str(abs_charge) if abs_charge > 1 else ""
    return f"{sign}{magnitude}"