"""SVG visualization for BZMol structures."""
from pathlib import Path
from xml.sax.saxutils import escape as _xml_escape

import numpy as np

from ._mol import BZBioMol, BZMol
def _draw_atom_box(
    svg_parts: list[str],
    /,
    *,
    x_offset: int,
    y_offset: int,
    box_width: int,
    box_height: int,
    atom_name: str,
    element: int,
    chain_color: str,
    has_coordinates: bool,
) -> None:
    """Append one atom cell (background rect plus centered name label) to ``svg_parts``.

    Args:
        svg_parts: Accumulated SVG fragments; modified in place.
        x_offset: Left edge of the box in pixels.
        y_offset: Top edge of the box in pixels.
        box_width: Box width in pixels.
        box_height: Box height in pixels.
        atom_name: Text drawn in the box. It is XML-escaped before insertion, so
            arbitrary names cannot break the SVG markup.
        element: Atomic number; selects the label color via ``_get_element_color``.
        chain_color: Fill color of the box background.
        has_coordinates: When False, the box and label are drawn faded instead of
            being omitted, so the row layout stays intact.
    """
    atom_color = _get_element_color(element)
    box_opacity = 1.0 if has_coordinates else 0.3
    text_opacity = 1.0 if has_coordinates else 0.35

    # Background box
    svg_parts.append(
        f'<rect x="{x_offset}" y="{y_offset}" '
        f'width="{box_width}" height="{box_height}" '
        f'fill="{chain_color}" opacity="{box_opacity}"/>'
    )

    # SVG text y is the baseline; the +5 nudge roughly centers 11px text vertically.
    text_x = x_offset + box_width // 2
    text_y = y_offset + box_height // 2 + 5
    svg_parts.append(
        f'<text x="{text_x}" y="{text_y}" '
        f'text-anchor="middle" font-family="monospace" font-size="11" '
        f'fill="{atom_color}" font-weight="bold" opacity="{text_opacity}">'
        f"{_xml_escape(str(atom_name))}</text>"
    )
def _visualize_biomol(
    mol: BZBioMol,
    /,
    *,
    svg_parts: list[str],
    chain_colors: list[str],
    box_width: int,
    box_height: int,
    padding: int,
    y_offset: int,
) -> tuple[int, int]:
    """Draw a BZBioMol as one section per chain, with one row of atom boxes per residue.

    Args:
        mol: Molecule whose chains, residues and atoms are drawn.
        svg_parts: Accumulated SVG fragments; modified in place.
        chain_colors: Box fill colors, cycled by chain index.
        box_width: Width of each atom box in pixels.
        box_height: Height of each atom box (and residue row) in pixels.
        padding: Left margin in pixels for chain and residue labels.
        y_offset: Top y coordinate at which drawing starts.

    Returns:
        ``(max_width, y_offset)``: the right edge in pixels of the widest residue
        row (0 if no residue was drawn), and the y coordinate just below the last
        chain section.
    """
    max_width = 0
    for chain_idx in range(mol.num_chains):
        # The label template below already prefixes "Chain ", so the fallback is
        # just the 1-based ordinal (avoids rendering "Chain Chain 1").
        if mol.chain_id is not None:
            chain_label = str(mol.chain_id[chain_idx])
        else:
            chain_label = str(chain_idx + 1)
        chain_color = chain_colors[chain_idx % len(chain_colors)]

        svg_parts.append(
            f'<text x="{padding}" y="{y_offset + 15}" '
            f'font-family="monospace" font-size="14" font-weight="bold">'
            f"Chain {_xml_escape(chain_label)}</text>"
        )
        y_offset += 25

        residue_indices = np.where(mol.residue_chain == chain_idx)[0]
        for res_idx in residue_indices:
            residue_label = str(mol.residue_name[res_idx])
            svg_parts.append(
                f'<text x="{padding}" y="{y_offset + box_height // 2 + 5}" '
                f'font-family="monospace" font-size="11" fill="#666">'
                f"{_xml_escape(residue_label)}</text>"
            )

            # Atom boxes start 50px right of the residue-label column and abut each other.
            x_offset = padding + 50
            atom_indices = np.where(mol.atom_residue == res_idx)[0]
            for atom_idx in atom_indices:
                # No resolution mask means every atom is treated as having coordinates.
                has_coordinates = bool(mol.atom_resolved[atom_idx]) if mol.atom_resolved is not None else True
                _draw_atom_box(
                    svg_parts,
                    x_offset=x_offset,
                    y_offset=y_offset,
                    box_width=box_width,
                    box_height=box_height,
                    atom_name=str(mol.atom_name[atom_idx]),
                    element=int(mol.atom_element[atom_idx]),
                    chain_color=chain_color,
                    has_coordinates=has_coordinates,
                )
                x_offset += box_width
            max_width = max(max_width, x_offset)
            y_offset += box_height

        # Fixed gap between chain sections
        y_offset += 15
    return max_width, y_offset
[docs]
def _visualize_mol(
    mol: BZMol,
    /,
    *,
    svg_parts: list[str],
    chain_color: str,
    box_width: int,
    box_height: int,
    padding: int,
    y_offset: int,
) -> tuple[int, int]:
    """Draw a molecule without chain/residue structure as a single labelled row of atoms.

    Returns ``(right_edge, y_offset)``, following the same convention as
    ``_visualize_biomol``: the x coordinate just past the last atom box, and the
    y coordinate just below the row.
    """
    svg_parts.append(
        f'<text x="{padding}" y="{y_offset + 15}" '
        f'font-family="monospace" font-size="14" font-weight="bold">'
        f"Molecule</text>"
    )
    y_offset += 25

    x_offset = padding
    for atom_idx, atom_name in enumerate(mol.atom_name):
        # No resolution mask means every atom is treated as having coordinates.
        has_coordinates = bool(mol.atom_resolved[atom_idx]) if mol.atom_resolved is not None else True
        _draw_atom_box(
            svg_parts,
            x_offset=x_offset,
            y_offset=y_offset,
            box_width=box_width,
            box_height=box_height,
            atom_name=str(atom_name),
            element=int(mol.atom_element[atom_idx]),
            chain_color=chain_color,
            has_coordinates=has_coordinates,
        )
        x_offset += box_width
    y_offset += box_height
    return x_offset, y_offset


def bzmol_to_svg(mol: BZMol | BZBioMol, /, *, box_width: int = 60, box_height: int = 30, padding: int = 5) -> str:
    """
    Generate an SVG visualization of a BZMol structure.

    A BZBioMol is drawn with each chain as a separate section, residues as rows
    and atoms as columns. Any other BZMol is drawn as a single row of atoms under
    a "Molecule" heading.

    Args:
        mol: The BZMol (or BZBioMol) to visualize.
        box_width: Width of each atom box in pixels.
        box_height: Height of each atom box in pixels.
        padding: Margin in pixels around the whole diagram; also the left offset
            of the labels. The gap between chains is fixed and independent of it.

    Returns:
        SVG string representing the BZMol structure.
    """
    svg_parts: list[str] = []

    # Background fills for atom boxes: per chain for a BZBioMol, first entry otherwise.
    chain_colors = [
        "#E8F4FD",
        "#FFF4E6",
        "#F3E5F5",
        "#E8F5E8",
        "#FFEBEE",
        "#F3E5FF",
        "#E1F5FE",
        "#FFF9C4",
        "#FCE4EC",
        "#E0F2F1",
    ]

    if isinstance(mol, BZBioMol):
        max_width, y_offset = _visualize_biomol(
            mol,
            svg_parts=svg_parts,
            chain_colors=chain_colors,
            box_width=box_width,
            box_height=box_height,
            padding=padding,
            y_offset=padding,
        )
    else:
        # e.g. molecules built from SMILES, which carry no chains or residues
        max_width, y_offset = _visualize_mol(
            mol,
            svg_parts=svg_parts,
            chain_color=chain_colors[0],
            box_width=box_width,
            box_height=box_height,
            padding=padding,
            y_offset=padding,
        )

    # Both drawing paths return the bare right/bottom edge of their content;
    # add a single padding as the right and bottom margins.
    svg_width = max_width + padding
    svg_height = y_offset + padding
    svg = f'<svg width="{svg_width}" height="{svg_height}" xmlns="http://www.w3.org/2000/svg">\n'
    svg += f'<rect width="{svg_width}" height="{svg_height}" fill="white"/>\n'
    svg += "\n".join(svg_parts)
    svg += "\n</svg>"
    return svg
def _get_element_color(atomic_number: int) -> str:
"""
Get a color for an element based on its atomic number.
Args:
atomic_number: The atomic number of the element.
Returns:
Hex color code for the element.
"""
# Common element colors (CPK coloring scheme) - darker for better text visibility
element_colors: dict[int, str] = {
1: "#808080", # H - gray (white text doesn't show well)
6: "#404040", # C - dark gray
7: "#1030D8", # N - blue
8: "#CC0000", # O - red
15: "#CC6600", # P - orange
16: "#CCCC00", # S - yellow-orange (pure yellow is hard to read)
9: "#50A020", # F - green
17: "#00CC00", # Cl - green
35: "#7C1919", # Br - dark red
53: "#6B006B", # I - purple
}
return element_colors.get(atomic_number, "#CCCCCC") # Default gray
[docs]
def save_bzmol_svg(
    mol: BZMol | BZBioMol,
    filepath: str | Path,
    /,
    *,
    box_width: int = 60,
    box_height: int = 30,
    padding: int = 5,
) -> None:
    """
    Save a BZMol visualization as an SVG file.

    Args:
        mol: The BZMol (or BZBioMol) to visualize.
        filepath: Path to save the SVG file. An existing file is overwritten.
        box_width: Width of each atom box in pixels.
        box_height: Height of each atom box in pixels.
        padding: Margin in pixels around the whole diagram.
    """
    svg_content = bzmol_to_svg(mol, box_width=box_width, box_height=box_height, padding=padding)
    # The SVG carries no XML declaration, so consumers assume UTF-8; write it
    # explicitly rather than relying on the platform's locale encoding.
    Path(filepath).write_text(svg_content, encoding="utf-8")