"""Some utility functions."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import importlib
import logging
import os
import os.path as op
import tempfile
import time
from collections.abc import Iterable
from threading import Thread

import numpy as np

from ._logging import logger
from .check import _check_option
from .config import get_config


class ProgressBar:
    """Generate a command-line progressbar.

    Parameters
    ----------
    iterable : iterable | int | None
        The iterable to use. Can also be an int for backward compatibility
        (acts like ``max_value``). If None, ``max_value`` must be given.
    initial_value : int
        Initial value of process, useful when resuming process from a specific
        value, defaults to 0.
    mesg : str
        Message to include at end of progress bar.
    max_total_width : int | str
        Maximum total message width. Can use "auto" (default) to try to set
        a sane value based on the current terminal width.
    max_value : int | None
        The max value. If None, the length of ``iterable`` will be used.
        Ignored when ``iterable`` is an int.
    which_tqdm : str | None
        Which tqdm module to use. Can be "tqdm", "tqdm.notebook", or "off".
        Defaults to ``None``, which uses the value of the MNE_TQDM environment
        variable, or ``"tqdm.auto"`` if that is not set.
    **kwargs : dict
        Additional keyword arguments for tqdm.
    """

    def __init__(
        self,
        iterable=None,
        initial_value=0,
        mesg=None,
        max_total_width="auto",
        max_value=None,
        *,
        which_tqdm=None,
        **kwargs,
    ):
        # The following mimics this, but with configurable module to use
        # from ..externals.tqdm import auto
        import tqdm

        if which_tqdm is None:
            which_tqdm = get_config("MNE_TQDM", "tqdm.auto")
        _check_option(
            "MNE_TQDM", which_tqdm[:5], ("tqdm", "tqdm.", "off"), extra="beginning"
        )
        logger.debug(f"Using ProgressBar with {which_tqdm}")
        backend = tqdm
        if which_tqdm not in ("tqdm", "off"):
            # import_module returns the named (possibly nested) submodule
            # itself, so no getattr walk over dotted names is needed.
            try:
                backend = importlib.import_module(which_tqdm)
            except Exception as exc:
                raise ValueError(
                    f"Unknown tqdm backend {repr(which_tqdm)}, got: {exc}"
                ) from None
        tqdm_cls = backend.tqdm

        defaults = dict(
            leave=True,
            mininterval=0.016,
            miniters=1,
            smoothing=0.05,
            bar_format="{percentage:3.0f}%|{bar}| {desc} : {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt:>11}{postfix}]",  # noqa: E501
        )
        # Caller-supplied kwargs take precedence over these defaults.
        for key, val in defaults.items():
            kwargs.setdefault(key, val)

        if isinstance(iterable, Iterable):
            self.iterable = iterable
            self.max_value = len(iterable) if max_value is None else max_value
        else:
            # An int ``iterable`` acts as the max value (``max_value`` is then
            # ignored); a None ``iterable`` falls back to ``max_value``.
            # int() still raises TypeError when both are None.
            source = max_value if iterable is None else iterable
            self.max_value = int(source)
            self.iterable = None

        if max_total_width == "auto":
            max_total_width = None  # tqdm's auto
        # Only a unique path is wanted here: the temporary file is deleted
        # when the ``with`` exits, and __enter__ later creates the memmap there.
        with tempfile.NamedTemporaryFile("wb", prefix="tmp_mne_prog") as tf:
            self._mmap_fname = tf.name
        self._mmap = None
        disable = logger.level > logging.INFO or which_tqdm == "off"
        self._tqdm = tqdm_cls(
            iterable=self.iterable,
            desc=mesg,
            total=self.max_value,
            initial=initial_value,
            ncols=max_total_width,
            disable=disable,
            **kwargs,
        )

    def update(self, cur_value):
        """Update progressbar with current value of process.

        Parameters
        ----------
        cur_value : number
            Current value of process.  Should be <= max_value (but this is not
            enforced).  The percent of the progressbar will be computed as
            ``(cur_value / max_value) * 100``.
        """
        self.update_with_increment_value(cur_value - self._tqdm.n)

    def update_with_increment_value(self, increment_value):
        """Update progressbar with an increment.

        Parameters
        ----------
        increment_value : int
            Value of the increment of process.  The percent of the progressbar
            will be computed as
            ``((cur_value + increment_value) / max_value) * 100``.
        """
        try:
            self._tqdm.update(increment_value)
        except TypeError:  # can happen during GC on Windows
            pass

    def __iter__(self):
        """Iterate to auto-increment the pbar with 1."""
        yield from self._tqdm

    def subset(self, idx):
        """Make a joblib-friendly index subset updater.

        The updater captures the shared memmap at construction time, so call
        this inside the ``with ProgressBar(...)`` block.

        Parameters
        ----------
        idx : ndarray
            List of indices for this subset.

        Returns
        -------
        updater : instance of PBSubsetUpdater
            Class with a ``.update(ii)`` method.
        """
        return _PBSubsetUpdater(self, idx)

    def __enter__(self):  # noqa: D105
        # This should only be used with pb.subset and parallelization
        if op.isfile(self._mmap_fname):
            os.remove(self._mmap_fname)
        # prevent corner cases where self.max_value == 0
        self._mmap = np.memmap(
            self._mmap_fname, bool, "w+", shape=max(self.max_value, 1)
        )
        self.update(0)  # must be zero as we just created the memmap
        # We need to control how the pickled bars exit: remove print statements
        self._thread = _UpdateThread(self)
        self._thread.start()
        return self

    def __exit__(self, type_, value, traceback):  # noqa: D105
        try:
            # Stop the background refresher first so it cannot interleave
            # with the final synchronization below.
            self._thread._mne_run = False
            self._thread.join()
            # Restore exit behavior for our one from the main thread
            self.update(self._mmap.sum())
            self._tqdm.close()
        finally:
            # Always release the memmap and its backing file, even if the
            # final update/close above raised.
            self._mmap = None
            if op.isfile(self._mmap_fname):
                os.remove(self._mmap_fname)

    def __del__(self):
        """Ensure output completes."""
        if getattr(self, "_tqdm", None) is not None:
            self._tqdm.close()


class _UpdateThread(Thread):
    """Daemon thread that periodically syncs a ProgressBar with its memmap."""

    def __init__(self, pb):
        super().__init__(daemon=True)
        # Cleared by ProgressBar.__exit__ to stop the polling loop.
        self._mne_run = True
        self._mne_pb = pb

    def run(self):
        """Poll the shared memmap until ``_mne_run`` is cleared."""
        while self._mne_run:
            self._mne_pb.update(self._mne_pb._mmap.sum())
            time.sleep(1.0 / 30.0)  # 30 Hz refresh is plenty


class _PBSubsetUpdater:
    """Mark completed items of one index subset in a ProgressBar's memmap."""

    def __init__(self, pb, idx):
        # Captured now: the memmap only exists inside ``with ProgressBar(...)``.
        self.mmap = pb._mmap
        self.idx = idx

    def update(self, ii):
        """Mark entry ``idx[ii - 1]`` as done (``ii`` is treated as 1-based)."""
        self.mmap[self.idx[ii - 1]] = True