"""
Thin wrappers around `concurrent.futures`.
"""
from __future__ import absolute_import

import sys
from contextlib import contextmanager

from tqdm import TqdmWarning
from tqdm.auto import tqdm as tqdm_auto

try:
    from operator import length_hint
except ImportError:
    def length_hint(it, default=0):
        """Returns `len(it)`, falling back to `default`"""
        try:
            return len(it)
        except TypeError:
            return default
try:
    from os import cpu_count
except ImportError:
    try:
        from multiprocessing import cpu_count
    except ImportError:
        def cpu_count():
            return 4

__author__ = {"github.com/": ["casperdcl"]}
__all__ = ['thread_map', 'process_map']


@contextmanager
def ensure_lock(tqdm_class, lock_name=""):
    """
    Get (create if necessary) and then restore `tqdm_class`'s lock.

    Parameters
    ----------
    tqdm_class  : class
        Object exposing `get_lock()` and `set_lock()`; its `_lock`
        attribute (if any) is saved on entry and restored on exit.
    lock_name  : str, optional
        Name of an attribute of the lock to use instead of the lock itself.
        If the lock has no such attribute, the lock itself is used
        [default: ""].

    Yields
    ------
    lock
        The lock installed on `tqdm_class` for the duration of the context.
    """
    old_lock = getattr(tqdm_class, '_lock', None)  # don't create a new lock
    lock = old_lock or tqdm_class.get_lock()  # maybe create a new lock
    lock = getattr(lock, lock_name, lock)  # maybe subtype
    tqdm_class.set_lock(lock)
    try:
        yield lock
    finally:
        # Restore even when the `with` body raises (e.g. a worker exception),
        # otherwise the temporary lock would leak onto `tqdm_class`.
        if old_lock is None:
            del tqdm_class._lock
        else:
            tqdm_class.set_lock(old_lock)


def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs):
    """
    Implementation of `thread_map` and `process_map`.

    Parameters
    ----------
    PoolExecutor  : class
        Executor class to instantiate and use as a context manager; its
        `map` method is called with `fn` and `iterables`.
    fn  : callable
        Function applied to the items of `iterables`.
    tqdm_class  : [default: tqdm.auto.tqdm].
    max_workers  : [default: min(32, cpu_count() + 4)].
    chunksize  : [default: 1].
    lock_name  : [default: "":str].
    total  : optional
        Passed to `tqdm_class`
        [default: length (hint) of the first iterable].

    Any remaining keyword arguments are passed to `tqdm_class`.

    Returns
    -------
    list
        Results yielded by `PoolExecutor.map`.
    """
    kwargs = tqdm_kwargs.copy()
    if "total" not in kwargs and iterables:
        # `length_hint` (unlike `len`) also accepts generators and other
        # unsized iterables, falling back to 0 when the length is unknown.
        kwargs["total"] = length_hint(iterables[0])
    tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
    max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4))
    chunksize = kwargs.pop("chunksize", 1)
    lock_name = kwargs.pop("lock_name", "")
    with ensure_lock(tqdm_class, lock_name=lock_name) as lk:
        pool_kwargs = {'max_workers': max_workers}
        sys_version = sys.version_info[:2]
        if sys_version >= (3, 7):
            # share lock in case workers are already using `tqdm`
            # (executors accept `initializer`/`initargs` from Python 3.7)
            pool_kwargs.update(initializer=tqdm_class.set_lock,
                               initargs=(lk,))
        map_args = {}
        if not (3, 0) < sys_version < (3, 5):
            # `chunksize` is not passed on Python 3.1 - 3.4
            map_args.update(chunksize=chunksize)
        with PoolExecutor(**pool_kwargs) as ex:
            return list(tqdm_class(ex.map(fn, *iterables, **map_args),
                                   **kwargs))


def thread_map(fn, *iterables, **tqdm_kwargs):
    """
    Equivalent of `list(map(fn, *iterables))`
    driven by `concurrent.futures.ThreadPoolExecutor`.

    Parameters
    ----------
    tqdm_class  : optional
        `tqdm` class to use for bars [default: tqdm.auto.tqdm].
    max_workers  : int, optional
        Maximum number of workers to spawn; passed to
        `concurrent.futures.ThreadPoolExecutor.__init__`.
        [default: min(32, cpu_count() + 4)].

    Any remaining keyword arguments are passed to `tqdm_class`.

    Returns
    -------
    list
        Results of applying `fn` to the items of `iterables`.
    """
    from concurrent.futures import ThreadPoolExecutor
    return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs)


def process_map(fn, *iterables, **tqdm_kwargs):
    """
    Equivalent of `list(map(fn, *iterables))`
    driven by `concurrent.futures.ProcessPoolExecutor`.

    Parameters
    ----------
    tqdm_class  : optional
        `tqdm` class to use for bars [default: tqdm.auto.tqdm].
    max_workers  : int, optional
        Maximum number of workers to spawn; passed to
        `concurrent.futures.ProcessPoolExecutor.__init__`.
        [default: min(32, cpu_count() + 4)].
    chunksize  : int, optional
        Size of chunks sent to worker processes; passed to
        `concurrent.futures.ProcessPoolExecutor.map`. [default: 1].
    lock_name  : str, optional
        Member of `tqdm_class.get_lock()` to use [default: mp_lock].

    Any remaining keyword arguments are passed to `tqdm_class`.

    Returns
    -------
    list
        Results of applying `fn` to the items of `iterables`.

    Warns
    -----
    TqdmWarning
        If `chunksize` is not given and the longest iterable's length
        (hint) exceeds 1000.
    """
    from concurrent.futures import ProcessPoolExecutor
    if iterables and "chunksize" not in tqdm_kwargs:
        # default `chunksize=1` has poor performance for large iterables
        # (most time spent dispatching items to workers).
        longest_iterable_len = max(map(length_hint, iterables))
        if longest_iterable_len > 1000:
            from warnings import warn
            warn("Iterable length %d > 1000 but `chunksize` is not set."
                 " This may seriously degrade multiprocess performance."
                 " Set `chunksize=1` or more." % longest_iterable_len,
                 TqdmWarning, stacklevel=2)
    if "lock_name" not in tqdm_kwargs:
        tqdm_kwargs = tqdm_kwargs.copy()
        tqdm_kwargs["lock_name"] = "mp_lock"
    return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)