# Source code for boltz_data.mol._visualize

"""SVG visualization for BZMol structures."""

from pathlib import Path
from xml.sax.saxutils import escape as _xml_escape

import numpy as np

from ._mol import BZBioMol, BZMol


def _draw_atom_box(
    svg_parts: list[str],
    /,
    *,
    x_offset: int,
    y_offset: int,
    box_width: int,
    box_height: int,
    atom_name: str,
    element: int,
    chain_color: str,
    has_coordinates: bool,
) -> None:
    """Append the SVG elements for a single atom box to ``svg_parts``.

    The box is filled with ``chain_color`` and labelled with ``atom_name``,
    drawn in the color returned by ``_get_element_color(element)``.

    Args:
        svg_parts: List of SVG fragments; a ``<rect>`` and a ``<text>``
            element are appended to it in place.
        x_offset: Left edge of the box in pixels.
        y_offset: Top edge of the box in pixels.
        box_width: Width of the box in pixels.
        box_height: Height of the box in pixels.
        atom_name: Label drawn inside the box. It is XML-escaped before being
            embedded, so names containing ``&``, ``<`` or ``>`` cannot produce
            a malformed document.
        element: Atomic number used to choose the label color.
        chain_color: Fill color of the box.
        has_coordinates: When False, the box and label are drawn with reduced
            opacity.
    """
    atom_color = _get_element_color(element)
    opacity = 1.0 if has_coordinates else 0.3
    text_opacity = 1.0 if has_coordinates else 0.35

    # Background box for the atom.
    svg_parts.append(
        f'<rect x="{x_offset}" y="{y_offset}" '
        f'width="{box_width}" height="{box_height}" '
        f'fill="{chain_color}" opacity="{opacity}"/>'
    )

    # Atom name centred horizontally; the +5 pushes the text baseline down so
    # the 11px glyphs sit roughly in the vertical middle of the box.
    text_x = x_offset + box_width // 2
    text_y = y_offset + box_height // 2 + 5
    svg_parts.append(
        f'<text x="{text_x}" y="{text_y}" '
        f'text-anchor="middle" font-family="monospace" font-size="11" '
        f'fill="{atom_color}" font-weight="bold" opacity="{text_opacity}">'
        f"{_xml_escape(str(atom_name))}</text>"
    )


def _visualize_biomol(
    mol: BZBioMol,
    /,
    *,
    svg_parts: list[str],
    chain_colors: list[str],
    box_width: int,
    box_height: int,
    padding: int,
    y_offset: int,
) -> tuple[int, int]:
    """Append SVG elements drawing a BZBioMol chain by chain.

    Each chain gets a bold "Chain <id>" label followed by one row per residue
    (residue name on the left, then one box per atom of that residue).

    Args:
        mol: The biomolecule to draw.
        svg_parts: List of SVG fragments, extended in place.
        chain_colors: Box fill colors, cycled over the chains.
        box_width: Width of each atom box in pixels.
        box_height: Height of each atom box (and residue row) in pixels.
        padding: Left margin for the labels in pixels.
        y_offset: Vertical position at which drawing starts.

    Returns:
        ``(max_width, y_offset)``. ``max_width`` is the right edge of the
        widest residue row in pixels (without any right margin), and
        ``y_offset`` is the vertical position after the last chain.
    """
    max_width = 0

    for chain_idx in range(mol.num_chains):
        # Fall back to a 1-based index when no chain IDs are stored; the label
        # below already supplies the "Chain " prefix.
        chain_id = mol.chain_id[chain_idx] if mol.chain_id is not None else chain_idx + 1
        chain_color = chain_colors[chain_idx % len(chain_colors)]

        # Chain label
        svg_parts.append(
            f'<text x="{padding}" y="{y_offset + 15}" '
            f'font-family="monospace" font-size="14" font-weight="bold">'
            f"Chain {_xml_escape(str(chain_id))}</text>"
        )
        y_offset += 25

        # Right edge of the widest residue row in this chain.
        chain_max_width = 0

        residue_indices = np.where(mol.residue_chain == chain_idx)[0]

        for res_idx in residue_indices:
            residue_name = mol.residue_name[res_idx]

            # Residue label at the start of the row
            svg_parts.append(
                f'<text x="{padding}" y="{y_offset + box_height // 2 + 5}" '
                f'font-family="monospace" font-size="11" fill="#666">'
                f"{_xml_escape(str(residue_name))}</text>"
            )

            # Atom boxes start after a fixed 50px column reserved for the label.
            x_offset = padding + 50

            atom_indices = np.where(mol.atom_residue == res_idx)[0]

            for atom_idx in atom_indices:
                atom_name = mol.atom_name[atom_idx]
                element = int(mol.atom_element[atom_idx])

                # Without a resolved mask every atom is treated as having coordinates.
                has_coordinates = bool(mol.atom_resolved[atom_idx]) if mol.atom_resolved is not None else True

                _draw_atom_box(
                    svg_parts,
                    x_offset=x_offset,
                    y_offset=y_offset,
                    box_width=box_width,
                    box_height=box_height,
                    atom_name=str(atom_name),
                    element=element,
                    chain_color=chain_color,
                    has_coordinates=has_coordinates,
                )

                x_offset += box_width  # boxes are drawn edge to edge

            chain_max_width = max(chain_max_width, x_offset)
            y_offset += box_height  # rows are drawn edge to edge

        # Fixed vertical gap between chains
        y_offset += 15
        max_width = max(max_width, chain_max_width)

    return max_width, y_offset


def bzmol_to_svg(
    mol: BZMol | BZBioMol,
    /,
    *,
    box_width: int = 60,
    box_height: int = 30,
    padding: int = 5,
) -> str:
    """
    Generate an SVG visualization of a BZMol structure.

    A BZBioMol is shown with each chain as a separate section, residues as
    rows and atoms as columns. Any other BZMol (e.g. built from SMILES) is
    shown as a single row of atoms under a "Molecule" label. Atoms whose
    ``atom_resolved`` flag is false are drawn faded.

    Args:
        mol: The BZMol to visualize.
        box_width: Width of each atom box in pixels.
        box_height: Height of each atom box in pixels.
        padding: Margin around the diagram in pixels. The gap between chains
            is a fixed 15 pixels and does not depend on this value.

    Returns:
        SVG string representing the BZMol structure.
    """
    svg_parts: list[str] = []

    # Drawing cursor and the right edge of the content (without right margin).
    y_offset = padding
    max_width = 0

    # Box fill colors, cycled over chains.
    chain_colors = [
        "#E8F4FD",
        "#FFF4E6",
        "#F3E5F5",
        "#E8F5E8",
        "#FFEBEE",
        "#F3E5FF",
        "#E1F5FE",
        "#FFF9C4",
        "#FCE4EC",
        "#E0F2F1",
    ]

    if isinstance(mol, BZBioMol):
        # Chains and residues are available: draw one section per chain.
        max_width, y_offset = _visualize_biomol(
            mol,
            svg_parts=svg_parts,
            chain_colors=chain_colors,
            box_width=box_width,
            box_height=box_height,
            padding=padding,
            y_offset=y_offset,
        )
    else:
        # No chain/residue information: draw every atom in a single row.
        x_offset = padding
        chain_color = chain_colors[0]

        svg_parts.append(
            f'<text x="{padding}" y="{y_offset + 15}" '
            f'font-family="monospace" font-size="14" font-weight="bold">'
            f"Molecule</text>"
        )
        y_offset += 25

        for atom_idx, atom_name in enumerate(mol.atom_name):
            element = int(mol.atom_element[atom_idx])

            # Without a resolved mask every atom is treated as having coordinates.
            has_coordinates = bool(mol.atom_resolved[atom_idx]) if mol.atom_resolved is not None else True

            _draw_atom_box(
                svg_parts,
                x_offset=x_offset,
                y_offset=y_offset,
                box_width=box_width,
                box_height=box_height,
                atom_name=str(atom_name),
                element=element,
                chain_color=chain_color,
                has_coordinates=has_coordinates,
            )

            x_offset += box_width

        # Right edge of the atom row. The right margin is added once below,
        # exactly as for the BZBioMol branch.
        max_width = x_offset
        y_offset += box_height

    # Wrap the collected elements in an SVG container with a white background.
    svg_width = max_width + padding
    svg_height = y_offset + padding
    svg = f'<svg width="{svg_width}" height="{svg_height}" xmlns="http://www.w3.org/2000/svg">\n'
    svg += f'<rect width="{svg_width}" height="{svg_height}" fill="white"/>\n'
    svg += "\n".join(svg_parts)
    svg += "\n</svg>"
    return svg
def _get_element_color(atomic_number: int) -> str: """ Get a color for an element based on its atomic number. Args: atomic_number: The atomic number of the element. Returns: Hex color code for the element. """ # Common element colors (CPK coloring scheme) - darker for better text visibility element_colors: dict[int, str] = { 1: "#808080", # H - gray (white text doesn't show well) 6: "#404040", # C - dark gray 7: "#1030D8", # N - blue 8: "#CC0000", # O - red 15: "#CC6600", # P - orange 16: "#CCCC00", # S - yellow-orange (pure yellow is hard to read) 9: "#50A020", # F - green 17: "#00CC00", # Cl - green 35: "#7C1919", # Br - dark red 53: "#6B006B", # I - purple } return element_colors.get(atomic_number, "#CCCCCC") # Default gray
def save_bzmol_svg(
    mol: BZMol | BZBioMol,
    filepath: str | Path,
    /,
    *,
    box_width: int = 60,
    box_height: int = 30,
    padding: int = 5,
) -> None:
    """
    Save a BZMol visualization as an SVG file.

    The document produced by ``bzmol_to_svg`` is written as UTF-8, overwriting
    any existing file at ``filepath``.

    Args:
        mol: The BZMol to visualize.
        filepath: Path to save the SVG file.
        box_width: Width of each atom box in pixels.
        box_height: Height of each atom box in pixels.
        padding: Margin around the diagram in pixels (boxes themselves are
            drawn without spacing).
    """
    svg_content = bzmol_to_svg(mol, box_width=box_width, box_height=box_height, padding=padding)
    # The SVG carries no XML declaration, so readers assume UTF-8; write in
    # that encoding instead of the platform default.
    Path(filepath).write_text(svg_content, encoding="utf-8")