Source code for boltz_data.draw.color._convert
import re
from collections.abc import Sequence
import numpy as np
import numpy.typing as npt
from ._css import CSS_COLORS
HEX_REGEX = re.compile("^#?[0-9a-fA-F]{6}$")
Color = str | Sequence[int] | Sequence[float] | npt.NDArray[np.float32] | npt.NDArray[np.uint8]
def is_hex_color(color: str, /) -> bool:
return bool(HEX_REGEX.match(color))
[docs]
def as_hex_color(color: Color, /) -> str:
    """
    Convert any color format to a hex string.

    Args:
        color: Color in any supported format:
            - CSS color name (e.g., "red", "blue"), matched case-insensitively
            - Hex string (e.g., "#ff0000", "ff0000", "#FF0000")
            - Integer sequence [0-255, 0-255, 0-255]
            - Float sequence [0.0-1.0, 0.0-1.0, 0.0-1.0]
            - Numpy array (int or float)

    Returns:
        Hex color string with leading #. Hex-string input is normalized to
        lowercase so it has the same form as the output for numeric input.
        CSS names return the stored ``CSS_COLORS`` value unchanged.

    Raises:
        ValueError: If color format is invalid.
    """
    # Handle string colors
    if isinstance(color, str):
        color_lower = color.lower()
        if color_lower in CSS_COLORS:
            return CSS_COLORS[color_lower]
        if is_hex_color(color):
            # Lowercase so "#FF0000" and [255, 0, 0] give the same string
            # (hex_from_rgb formats with lowercase "{:02x}").
            return "#" + color_lower.removeprefix("#")
        msg = f"Unrecognized color string: {color}"
        raise ValueError(msg)
    # Numeric input: validate/normalize to float RGB, then format.
    return hex_from_rgb(as_rgb_float_color(color))
[docs]
def as_rgb_float_color(color: Color, /) -> npt.NDArray[np.float32]:
    """
    Convert any color format to a float RGB numpy array in range [0, 1].

    Args:
        color: Color in any supported format:
            - CSS color name or hex string
            - Integer sequence/array with values in [0, 255] (any integer dtype)
            - Float sequence/array with values in [0.0, 1.0] (any float dtype)

    Returns:
        Numpy float32 array of shape (3,) with values in range [0, 1].

    Raises:
        ValueError: If the string is neither a CSS name nor a hex color, the
            input is not a 3-element triplet, a value is out of range or NaN,
            or the dtype is neither integer nor floating point.
    """
    # Handle string colors
    if isinstance(color, str):
        color_lower = color.lower()
        if color_lower in CSS_COLORS:
            hex_str = CSS_COLORS[color_lower]
        elif is_hex_color(color):
            hex_str = color
        else:
            msg = f"Unrecognized color string: {color}"
            raise ValueError(msg)
        # Enforce the float32 return contract whatever dtype rgb_from_hex yields.
        return np.asarray(rgb_from_hex(hex_str), dtype=np.float32)

    # Convert to numpy array if it's a sequence
    if not isinstance(color, np.ndarray):
        color = np.array(color)
    if color.shape != (3,):
        msg = f"Color must be a triplet, got shape {color.shape}"
        raise ValueError(msg)

    # Integer input of any signed/unsigned width is on the 0-255 scale.
    # (bool is not an np.integer subtype, so it still falls through to the error.)
    if np.issubdtype(color.dtype, np.integer):
        if color.max() > 255 or color.min() < 0:
            msg = f"Integer RGB values must be in range [0, 255], got [{color.min()}, {color.max()}]"
            raise ValueError(msg)
        return color.astype(np.float32) / 255.0

    # Float input of any width is already on the 0-1 scale.
    if np.issubdtype(color.dtype, np.floating):
        # Test membership in [0, 1] instead of looking for violations: NaN fails
        # every comparison, so it would slip past `max() > 1.0 or min() < 0.0`.
        if not np.all((color >= 0.0) & (color <= 1.0)):
            msg = f"Float RGB values must be in range [0.0, 1.0], got [{color.min()}, {color.max()}]"
            raise ValueError(msg)
        return color.astype(np.float32)

    msg = f"Unsupported color dtype: {color.dtype}"
    raise ValueError(msg)
[docs]
def as_rgb_int_color(color: Color, /) -> npt.NDArray[np.uint8]:
    """
    Convert any color format to an integer RGB numpy array in range [0, 255].

    Args:
        color: Color in any supported format (see ``as_rgb_float_color``).

    Returns:
        Numpy array of shape (3,) with uint8 values in range [0, 255], each
        channel rounded to the nearest level.

    Raises:
        ValueError: If color format is invalid.
    """
    rgb_float = as_rgb_float_color(color)
    # Round to nearest (np.rint, ties to even) rather than truncate: truncation
    # maps 0.5 to 127 and sends only exactly 1.0 to 255.
    return np.rint(rgb_float * 255).astype(np.uint8)
def hex_from_rgb(rgb: npt.NDArray[np.float32] | npt.NDArray[np.uint8]) -> str:
if rgb.shape != (3,):
msg = f"Cannot interpret numpy array of shape {rgb.shape} as RGB triplet"
raise ValueError(msg)
match rgb.dtype:
case np.uint8:
if rgb.max() > 255 or rgb.min() < 0:
msg = f"RGB values must be in range [0, 255], got [{rgb.min()}, {rgb.max()}]"
raise ValueError(msg)
rgb = np.array(rgb, dtype=np.uint8)
case np.float32 | np.float64:
rgb = np.array(rgb * 255, dtype=np.uint8)
case _:
msg = f"Cannot interpret numpy array of dtype {rgb.dtype} as RGB triplet"
raise ValueError(msg)
return "#{:02x}{:02x}{:02x}".format(*tuple(rgb.tolist()))
def rgb_from_hex(hex_str: str, /) -> npt.NDArray[np.float32]:
    """
    Parse a 6-digit hex color (leading "#" optional) into float RGB in [0, 1].

    Args:
        hex_str: Hex color such as "#ff0000" or "FF0000".

    Returns:
        Numpy float32 array of shape (3,) with values in range [0, 1].

    Raises:
        ValueError: If, after stripping leading "#", the string is not exactly
            six hexadecimal digits.
    """
    digits = hex_str.lstrip("#")
    # Validate explicitly: int(..., 16) alone would accept pairs like "+f" and
    # fails with an unhelpful message on short input such as "#fff".
    if len(digits) != 6 or not all(c in "0123456789abcdefABCDEF" for c in digits):
        msg = f"Invalid hex color: {hex_str!r}"
        raise ValueError(msg)
    channels = [int(digits[i : i + 2], 16) for i in (0, 2, 4)]
    # float32 to match the annotated return type.
    return np.array(channels, dtype=np.float32) / 255