# Source code for boltz_data.draw.mol3d._render

import numpy as np
import numpy.typing as npt
import svg

from boltz_data.draw.color import as_hex_color, darken
from boltz_data.geom import BoundingBoxes

from ._primitive import DrawnLineSegment, DrawnSphere, Primitive, Text

# Rendering constants.
# Maximum darkening strength: each primitive is darkened by
# ``relative_depth * DEPTH_DARKEN_FACTOR`` (relative depth lies in [0, 1]).
DEPTH_DARKEN_FACTOR: float = 0.6
# Multiplier from scene coordinate units to the SVG ``width``/``height`` attributes.
SVG_SCALE: int = 15


def draw_3d_svg(*primitives: Primitive, padding: float = 2, depth: bool = True) -> str:
    """
    Render 3D primitives to an SVG string with depth-based shading.

    Projects 3D primitives onto 2D using orthographic projection (dropping the Z
    coordinate). Applies depth-based darkening to create a pseudo-3D effect, and paints
    primitives back-to-front (largest Z first) so smaller-Z primitives end up on top.

    Args:
        *primitives: Variable number of primitives to render. Only ``DrawnSphere``,
            ``DrawnLineSegment`` and ``Text`` primitives are drawn; any other primitive
            still contributes to the canvas bounds but is not rendered.
        padding: Padding around the bounding box in coordinate units.
        depth: Whether to apply depth-based shading.

    Returns:
        SVG string representation of the 3D scene.

    Raises:
        ValueError: If no primitives are provided.
    """
    if len(primitives) == 0:
        msg = "No primitives to draw."
        raise ValueError(msg)

    bounding_boxes = BoundingBoxes.from_list([pri.bounding_box for pri in primitives])
    bounding_box = bounding_boxes.bounding_box
    min_depth = bounding_box.min[2]
    max_depth = bounding_box.max[2]
    depth_span = max_depth - min_depth

    def relative_depth(coord: npt.NDArray[np.float32]) -> float:
        # Map Z onto [0, 1]: 0 at the smallest Z of the scene (painted last, i.e. on top),
        # 1 at the largest Z. A flat scene has no depth range; dividing by zero would give
        # NaN (or raise), and NaN shading survives even ``depth=False`` since NaN * 0 is NaN.
        # Treat every primitive as frontmost in that case.
        if depth_span <= 0:
            return 0.0
        return float((coord[2] - min_depth) / depth_span)

    # Back-to-front painter's order: the largest Z is emitted first so it is overdrawn.
    ordered = sorted(primitives, key=lambda pri: pri.center[2], reverse=True)
    depth_factor = DEPTH_DARKEN_FACTOR if depth else 0

    elements: list[svg.Element] = []
    for primitive in ordered:
        if isinstance(primitive, DrawnSphere):
            elements.append(
                _render_sphere(primitive, relative_depth=relative_depth(primitive.center), depth_factor=depth_factor)
            )
        elif isinstance(primitive, DrawnLineSegment):
            # A segment renders as two half-lines, hence ``extend``.
            elements.extend(
                _render_line_segment(
                    primitive, relative_depth=relative_depth(primitive.center), depth_factor=depth_factor
                )
            )
        elif isinstance(primitive, Text):
            elements.append(
                _render_text(primitive, relative_depth=relative_depth(primitive.center), depth_factor=depth_factor)
            )

    # The viewBox is the scene's XY bounding box grown by ``padding`` on every side;
    # the rendered size is that same extent scaled by SVG_SCALE.
    view_width = bounding_box.size[0] + 2 * padding
    view_height = bounding_box.size[1] + 2 * padding
    canvas = svg.SVG(
        elements=elements,
        viewBox=svg.ViewBoxSpec(
            bounding_box.min[0] - padding,
            bounding_box.min[1] - padding,
            view_width,
            view_height,
        ),
        width=view_width * SVG_SCALE,
        height=view_height * SVG_SCALE,
    )
    return canvas.as_str()
def _render_sphere(primitive: DrawnSphere, /, *, relative_depth: float, depth_factor: float) -> svg.Circle:
    """
    Draw a sphere as an SVG circle at its projected (x, y) position.

    Fill and outline colors are darkened by ``relative_depth * depth_factor``. A falsy
    color (or outline color) is rendered as ``"transparent"`` instead.
    """
    shade_amount = relative_depth * depth_factor

    fill_value = "transparent"
    if primitive.color:
        fill_value = as_hex_color(darken(primitive.color, shade_amount))

    stroke_value = "transparent"
    if primitive.outline_color:
        stroke_value = as_hex_color(darken(primitive.outline_color, shade_amount))

    return svg.Circle(
        cx=primitive.center[0],
        cy=primitive.center[1],
        r=primitive.radius,
        fill_opacity=primitive.opacity,
        fill=fill_value,
        stroke=stroke_value,
        stroke_width=primitive.outline_width,
        # Outline is always dashed: a single dasharray value means equal dash and gap lengths.
        stroke_dasharray=[0.14],
    )


def _render_line_segment(
    primitive: DrawnLineSegment, /, *, relative_depth: float, depth_factor: float
) -> list[svg.Line]:
    """
    Draw a segment as two SVG lines joined at ``primitive.center``.

    The first half (``points[0]`` to center) uses ``start_color`` and the second half
    (center to ``points[1]``) uses ``end_color``. Both are darkened by
    ``relative_depth * depth_factor`` and share the segment's stroke width.
    """
    shade_amount = relative_depth * depth_factor
    midpoint = primitive.center
    points = primitive.points

    # (x1, x2, y1, y2, color) for each half of the segment.
    halves = (
        (points[0, 0], midpoint[0], points[0, 1], midpoint[1], primitive.start_color),
        (midpoint[0], points[1, 0], midpoint[1], points[1, 1], primitive.end_color),
    )
    return [
        svg.Line(
            x1=x_from,
            x2=x_to,
            y1=y_from,
            y2=y_to,
            stroke=as_hex_color(darken(half_color, shade_amount)),
            stroke_width=primitive.width,
        )
        for x_from, x_to, y_from, y_to, half_color in halves
    ]


def _render_text(primitive: Text, /, *, relative_depth: float, depth_factor: float) -> svg.Text:
    """
    Draw a text label centred (horizontally and vertically) on its projected position.

    The text is bold Arial (falling back to sans-serif), filled with ``primitive.color``
    darkened by ``relative_depth * depth_factor``.
    """
    text_fill = as_hex_color(darken(primitive.color, relative_depth * depth_factor))
    anchor_x, anchor_y = primitive.center[0], primitive.center[1]
    return svg.Text(
        x=anchor_x,
        y=anchor_y,
        font_size=primitive.font_size,
        fill=text_fill,
        text=primitive.text,
        font_family="Arial, sans-serif",
        font_weight="bold",
        text_anchor="middle",
        dominant_baseline="middle",
    )