import numpy as np
import numpy.typing as npt
import svg
from boltz_data.draw.color import as_hex_color, darken
from boltz_data.geom import BoundingBoxes
from ._primitive import DrawnLineSegment, DrawnSphere, Primitive, Text
# Rendering constants.
# Upper bound of the amount handed to ``darken`` when depth shading is enabled:
# it is multiplied by a primitive's relative depth before darkening its colors.
DEPTH_DARKEN_FACTOR: float = 0.6
# Multiplier from scene coordinate units to the SVG's width/height attributes.
SVG_SCALE: int = 15
def draw_3d_svg(*primitives: Primitive, padding: float = 2, depth: bool = True) -> str:
    """
    Render 3D primitives to an SVG string with depth-based shading.

    Projects 3D primitives onto 2D using orthographic projection (dropping the Z
    coordinate). Primitives are painted in order of decreasing Z of their center
    (painter's algorithm), so primitives with a smaller Z are drawn over those with
    a larger Z. When ``depth`` is enabled, each primitive's colors are darkened in
    proportion to where its center's Z lies within the scene's Z extent, giving a
    pseudo-3D effect.

    Only ``DrawnSphere``, ``DrawnLineSegment`` and ``Text`` primitives are rendered;
    primitives of any other type are silently skipped.

    Args:
        *primitives: Variable number of primitives to render.
        padding: Padding around the bounding box in coordinate units.
        depth: Whether to apply depth-based shading.

    Returns:
        SVG string representation of the 3D scene.

    Raises:
        ValueError: If no primitives are provided.
    """
    if not primitives:
        msg = "No primitives to draw."
        raise ValueError(msg)

    bounding_box = BoundingBoxes.from_list([pri.bounding_box for pri in primitives]).bounding_box
    min_depth = bounding_box.min[2]
    depth_range = bounding_box.max[2] - min_depth

    def relative_depth(coord: npt.NDArray[np.float32]) -> float:
        """Return where ``coord``'s Z lies between the scene's min and max Z (0 for a flat scene)."""
        # A scene with no Z extent (e.g. every primitive at z=0) would otherwise divide
        # by zero and feed NaN/inf into the color math. That happens even when ``depth``
        # is False, because NaN * 0 is still NaN.
        if depth_range <= 0:
            return 0.0
        return float((coord[2] - min_depth) / depth_range)

    depth_factor = DEPTH_DARKEN_FACTOR if depth else 0

    elements: list[svg.Element] = []
    # Largest Z first, so primitives with smaller Z are painted on top of them.
    for primitive in sorted(primitives, key=lambda pri: pri.center[2], reverse=True):
        shade = relative_depth(primitive.center)
        if isinstance(primitive, DrawnSphere):
            elements.append(_render_sphere(primitive, relative_depth=shade, depth_factor=depth_factor))
        elif isinstance(primitive, DrawnLineSegment):
            elements.extend(_render_line_segment(primitive, relative_depth=shade, depth_factor=depth_factor))
        elif isinstance(primitive, Text):
            elements.append(_render_text(primitive, relative_depth=shade, depth_factor=depth_factor))
        # Any other primitive type is intentionally ignored.

    # The viewBox is expressed in scene coordinates; width/height scale it for display.
    view_width = bounding_box.size[0] + 2 * padding
    view_height = bounding_box.size[1] + 2 * padding
    canvas = svg.SVG(
        elements=elements,
        viewBox=svg.ViewBoxSpec(
            bounding_box.min[0] - padding,
            bounding_box.min[1] - padding,
            view_width,
            view_height,
        ),
        width=view_width * SVG_SCALE,
        height=view_height * SVG_SCALE,
    )
    return canvas.as_str()
def _render_sphere(primitive: DrawnSphere, /, *, relative_depth: float, depth_factor: float) -> svg.Circle:
    """
    Build the SVG circle for a sphere projected onto the XY plane.

    Both the fill and outline colors are darkened by ``relative_depth * depth_factor``.
    A falsy color (fill or outline) is rendered as ``"transparent"``. The outline always
    uses a dash pattern of ``[0.14]``, and ``primitive.opacity`` is applied to the fill only.
    """
    # NOTE(review): the truthiness checks below assume colors are not numpy arrays
    # (an array would raise on ``bool()``); confirm against the DrawnSphere definition.
    darken_by = relative_depth * depth_factor

    if primitive.color:
        fill_color = as_hex_color(darken(primitive.color, darken_by))
    else:
        fill_color = "transparent"

    if primitive.outline_color:
        outline_color = as_hex_color(darken(primitive.outline_color, darken_by))
    else:
        outline_color = "transparent"

    center = primitive.center
    return svg.Circle(
        cx=center[0],
        cy=center[1],
        r=primitive.radius,
        fill=fill_color,
        fill_opacity=primitive.opacity,
        stroke=outline_color,
        stroke_width=primitive.outline_width,
        stroke_dasharray=[0.14],
    )
def _render_line_segment(
    primitive: DrawnLineSegment, /, *, relative_depth: float, depth_factor: float
) -> list[svg.Line]:
    """
    Build the SVG lines for a segment, split at its center into two halves.

    The first half (start point to center) is stroked with ``start_color`` and the
    second half (center to end point) with ``end_color``. Both halves are darkened by
    the same ``relative_depth * depth_factor`` amount and share ``primitive.width``.
    """
    mid = primitive.center
    shade = relative_depth * depth_factor
    # (from-point, to-point, color) for each half, in drawing order.
    halves = (
        ((primitive.points[0, 0], primitive.points[0, 1]), (mid[0], mid[1]), primitive.start_color),
        ((mid[0], mid[1]), (primitive.points[1, 0], primitive.points[1, 1]), primitive.end_color),
    )
    return [
        svg.Line(
            x1=head[0],
            y1=head[1],
            x2=tail[0],
            y2=tail[1],
            stroke=as_hex_color(darken(color, shade)),
            stroke_width=primitive.width,
        )
        for head, tail, color in halves
    ]
def _render_text(primitive: Text, /, *, relative_depth: float, depth_factor: float) -> svg.Text:
    """
    Build a bold, centered SVG text label at the primitive's projected XY position.

    The text color is darkened by ``relative_depth * depth_factor``. The label is
    centered horizontally and vertically on the primitive's center.
    """
    label_color = as_hex_color(darken(primitive.color, relative_depth * depth_factor))
    anchor = primitive.center
    return svg.Text(
        text=primitive.text,
        x=anchor[0],
        y=anchor[1],
        fill=label_color,
        font_size=primitive.font_size,
        font_family="Arial, sans-serif",
        font_weight="bold",
        text_anchor="middle",
        dominant_baseline="middle",
    )