# Source code for boltz_data.parallel._map
"""Parallel mapping utilities using multiprocessing."""
from collections.abc import Callable, Generator, Iterable, Sized
from multiprocessing import Pool, cpu_count
from typing import TypeVar, cast
from tqdm import tqdm # type: ignore[import-untyped]
T = TypeVar("T")
R = TypeVar("R")
def chunk_iterable_from_iterable(iterable: Iterable[T], *, chunk_size: int) -> Iterable[list[T]]:
    """Yield successive chunks of at most ``chunk_size`` items from ``iterable``.

    Args:
        iterable: Source of items to group; consumed lazily, in order.
        chunk_size: Maximum number of items per chunk; must be positive.

    Returns:
        An iterator of lists of up to ``chunk_size`` items each. The final
        chunk may be shorter; an empty iterable yields nothing.

    Raises:
        ValueError: If ``chunk_size`` is not positive (raised eagerly at call
            time, not on first iteration).
    """
    if chunk_size < 1:
        msg = f"chunk_size must be positive, got {chunk_size}"
        raise ValueError(msg)
    # Delegate to a private generator so the validation above fires
    # immediately rather than on the first next() call.
    return _chunked(iterable, chunk_size)


def _chunked(iterable: Iterable[T], chunk_size: int) -> Generator[list[T]]:
    """Accumulate items into lists of ``chunk_size``, flushing the remainder."""
    chunk: list[T] = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) >= chunk_size:
            yield chunk
            chunk = []
    if chunk:  # trailing partial chunk
        yield chunk
def parallel_imap(
    func: Callable[[T], R],
    iterable: Iterable[T],
    /,
    *,
    n_jobs: int | None = None,
    chunk_size: int | None = None,
    desc: str | None = None,
) -> Generator[R]:
    """
    Apply a function to an iterable in parallel using multiprocessing.

    Args:
        func: Function to apply to each element; must be picklable.
        iterable: Iterable of items to process.
        n_jobs: Number of parallel processes (default: cpu_count()).
        chunk_size: Number of items sent to a worker per task
            (default: 1, multiprocessing's ``imap`` default).
        desc: Description for progress bar.

    Returns:
        Generator of results in the same order as the input iterable.

    Raises:
        ValueError: If ``n_jobs`` or ``chunk_size`` is not positive.

    Note:
        This is a generator, so argument validation only runs once
        iteration begins.
    """
    if n_jobs is None:
        n_jobs = cpu_count()
    elif n_jobs < 1:
        msg = f"n_jobs must be positive, got {n_jobs}"
        raise ValueError(msg)
    if chunk_size is not None and chunk_size < 1:
        msg = f"chunk_size must be positive, got {chunk_size}"
        raise ValueError(msg)
    # A total enables tqdm's ETA/percentage display; only Sized inputs have one.
    total = len(cast("Sized", iterable)) if hasattr(iterable, "__len__") else None
    with Pool(processes=n_jobs) as pool:
        # Pool.imap preserves input order and natively batches items into
        # chunks of `chunksize` per worker task; None falls back to the
        # previous behavior (chunksize=1).
        yield from tqdm(
            pool.imap(func, iterable, chunksize=chunk_size or 1),
            total=total,
            desc=desc,
        )