# Source code for boltz_data.parallel._map

"""Parallel mapping utilities using multiprocessing."""

from collections.abc import Callable, Generator, Iterable, Sized
from multiprocessing import Pool, cpu_count
from typing import TypeVar, cast

from tqdm import tqdm  # type: ignore[import-untyped]

T = TypeVar("T")  # element type of the input iterable
R = TypeVar("R")  # result type produced by the mapped function


def chunk_iterable_from_iterable(iterable: Iterable[T], *, chunk_size: int) -> Iterable[list[T]]:
    """Yield consecutive lists of up to ``chunk_size`` items taken from ``iterable``.

    Every chunk except possibly the last holds exactly ``chunk_size`` items;
    the final chunk carries the remainder (and is omitted when empty).
    """
    buffer: list[T] = []
    for element in iterable:
        buffer.append(element)
        if len(buffer) < chunk_size:
            continue
        # Buffer is full — hand it off and start collecting the next chunk.
        yield buffer
        buffer = []
    # Flush any leftover items that did not fill a whole chunk.
    if buffer:
        yield buffer


def parallel_imap(
    func: Callable[[T], R],
    iterable: Iterable[T],
    /,
    *,
    n_jobs: int | None = None,
    chunk_size: int | None = None,
    desc: str | None = None,
) -> Generator[R]:
    """
    Apply a function to an iterable in parallel using multiprocessing.

    Args:
        func: Function to apply to each element; must be picklable
            (e.g. a module-level function), since it is sent to worker
            processes.
        iterable: Iterable of items to process.
        n_jobs: Number of parallel processes (default: ``cpu_count()``).
        chunk_size: Number of items sent to a worker per task
            (default: 1, matching ``Pool.imap``'s own default). Larger
            chunks reduce inter-process overhead for cheap ``func`` calls.
        desc: Description for the progress bar.

    Yields:
        Results in the same order as the input iterable.

    Raises:
        ValueError: If ``n_jobs`` or ``chunk_size`` is given and not positive.
    """
    if n_jobs is None:
        n_jobs = cpu_count()
    elif n_jobs < 1:
        msg = f"n_jobs must be positive, got {n_jobs}"
        raise ValueError(msg)

    if chunk_size is None:
        chunk_size = 1  # Pool.imap's default: one item per worker task.
    elif chunk_size < 1:
        msg = f"chunk_size must be positive, got {chunk_size}"
        raise ValueError(msg)

    # A total is only known for sized iterables; tqdm accepts total=None
    # and then shows an open-ended progress bar.
    total = len(cast("Sized", iterable)) if hasattr(iterable, "__len__") else None

    with Pool(processes=n_jobs) as pool:
        # imap (unlike imap_unordered) preserves input order; chunksize
        # batches items per task, which Pool.imap supports natively.
        yield from tqdm(
            pool.imap(func, iterable, chunksize=chunk_size),
            total=total,
            desc=desc,
        )