# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np
from scipy.signal import get_window

from .utils import _ensure_int, logger, verbose

###############################################################################
# Class for interpolation between adjacent points


class _Interp2:
    r"""Interpolate between two points.

    Parameters
    ----------
    control_points : array, shape (n_changes,)
        The control points (indices) to use.
    values : callable | array, shape (n_changes, ...)
        Callable that takes the control point and returns a list of
        arrays that must be interpolated. An ndarray (or list of
        ndarrays with one row per control point) is wrapped in such a
        callable automatically.
    interp : str
        Can be 'zero', 'linear', 'hann', or 'cos2' (same as hann).

    Notes
    -----
    This will process data using overlapping windows of potentially
    different sizes to achieve a constant output value using different
    2-point interpolation schemes. For example, for linear interpolation,
    and window sizes of 6 and 17, this would look like::

        1 _     _
          |\   / '-.        .-'
          | \ /     '-.  .-'
          |  x          |-.-|
          | / \      .-'  '-.
          |/   \_.-'         '-.
        0 +----|----|----|----|---
          0    5   10   15   20   25

    """

    def __init__(self, control_points, values, interp="hann"):
        # set up interpolation
        self.control_points = np.array(control_points, int).ravel()
        if not np.array_equal(np.unique(self.control_points), self.control_points):
            raise ValueError("Control points must be sorted and unique")
        if len(self.control_points) == 0:
            raise ValueError("There must be at least one control point")
        if not (self.control_points >= 0).all():
            raise ValueError(
                "All control points must be non-negative "
                f"(got {self.control_points[:3]})"
            )
        if isinstance(values, np.ndarray):
            values = [values]
        if isinstance(values, (list, tuple)):
            for v in values:
                if not (v is None or isinstance(v, np.ndarray)):
                    raise TypeError(
                        'All entries in "values" must be ndarray or None, got '
                        f"{type(v)}"
                    )
                if v is not None and v.shape[0] != len(self.control_points):
                    raise ValueError(
                        "Values, if provided, must be the same length as the number "
                        f"of control points ({len(self.control_points)}), got "
                        f"{v.shape[0]}"
                    )
            use_values = values

            def val(pt):
                idx = np.where(control_points == pt)[0][0]
                return [v[idx] if v is not None else None for v in use_values]

            values = val
        self.values = values
        self.n_last = None
        self._position = 0  # start at zero
        self._left_idx = 0
        self._left = self._right = self._use_interp = None
        known_types = ("cos2", "linear", "zero", "hann")
        if interp not in known_types:
            raise ValueError(f'interp must be one of {known_types}, got "{interp}"')
        self._interp = interp

    def feed_generator(self, n_pts):
        """Feed data and get interpolators as a generator."""
        self.n_last = 0
        n_pts = _ensure_int(n_pts, "n_pts")
        original_position = self._position
        stop = self._position + n_pts
        logger.debug(f"Feed {n_pts} ({self._position}-{stop})")
        used = np.zeros(n_pts, bool)
        if self._left is None:  # first one
            logger.debug(f"  Eval @ 0 ({self.control_points[0]})")
            self._left = self.values(self.control_points[0])
            if len(self.control_points) == 1:
                self._right = self._left
        n_used = 0

        # Left zero-order hold condition
        if self._position < self.control_points[self._left_idx]:
            n_use = min(self.control_points[self._left_idx] - self._position, n_pts)
            logger.debug(f"  Left ZOH {n_use}")
            this_sl = slice(None, n_use)
            assert used[this_sl].size == n_use
            assert not used[this_sl].any()
            used[this_sl] = True
            yield [this_sl, self._left, None, None]
            self._position += n_use
            n_used += n_use
            self.n_last += 1

        # Standard interpolation condition
        stop_right_idx = np.where(self.control_points >= stop)[0]
        if len(stop_right_idx) == 0:
            stop_right_idx = [len(self.control_points) - 1]
        stop_right_idx = stop_right_idx[0]

        left_idxs = np.arange(self._left_idx, stop_right_idx)
        self.n_last += max(len(left_idxs) - 1, 0)
        for bi, left_idx in enumerate(left_idxs):
            if left_idx != self._left_idx or self._right is None:
                if self._right is not None:
                    assert left_idx == self._left_idx + 1
                    self._left = self._right
                    self._left_idx += 1
                    self._use_interp = None  # need to recreate it
                eval_pt = self.control_points[self._left_idx + 1]
                logger.debug(f"  Eval @ {self._left_idx + 1} ({eval_pt})")
                self._right = self.values(eval_pt)
            assert self._right is not None
            left_point = self.control_points[self._left_idx]
            right_point = self.control_points[self._left_idx + 1]
            if self._use_interp is None:
                interp_span = right_point - left_point
                if self._interp == "zero":
                    self._use_interp = None
                elif self._interp == "linear":
                    self._use_interp = np.linspace(
                        1.0, 0.0, interp_span, endpoint=False
                    )
                else:  # self._interp in ('cos2', 'hann'):
                    self._use_interp = np.cos(
                        np.linspace(0, np.pi / 2.0, interp_span, endpoint=False)
                    )
                    self._use_interp *= self._use_interp
            n_use = min(stop, right_point) - self._position
            if n_use > 0:
                logger.debug(
                    f"  Interp {self._interp} {n_use} ({left_point}-{right_point})"
                )
                interp_start = self._position - left_point
                assert interp_start >= 0
                if self._use_interp is None:
                    this_interp = None
                else:
                    this_interp = self._use_interp[interp_start : interp_start + n_use]
                    assert this_interp.size == n_use
                this_sl = slice(n_used, n_used + n_use)
                assert used[this_sl].size == n_use
                assert not used[this_sl].any()
                used[this_sl] = True
                yield [this_sl, self._left, self._right, this_interp]
                self._position += n_use
                n_used += n_use

        # Right zero-order hold condition
        if self.control_points[self._left_idx] <= self._position:
            n_use = stop - self._position
            if n_use > 0:
                logger.debug(f"  Right ZOH {n_use}")
                this_sl = slice(n_pts - n_use, None)
                assert not used[this_sl].any()
                used[this_sl] = True
                assert self._right is not None
                yield [this_sl, self._right, None, None]
                self._position += n_use
                n_used += n_use
                self.n_last += 1
        assert self._position == stop
        assert n_used == n_pts
        assert used.all()
        assert self._position == original_position + n_pts

    def feed(self, n_pts):
        """Feed data and get interpolated values."""
        # Convenience function for assembly
        out_arrays = None
        for o in self.feed_generator(n_pts):
            if out_arrays is None:
                out_arrays = [
                    np.empty(v.shape + (n_pts,)) if v is not None else None
                    for v in o[1]
                ]
            for ai, arr in enumerate(out_arrays):
                if arr is not None:
                    if o[3] is None:
                        arr[..., o[0]] = o[1][ai][..., np.newaxis]
                    else:
                        arr[..., o[0]] = o[1][ai][..., np.newaxis] * o[3] + o[2][ai][
                            ..., np.newaxis
                        ] * (1.0 - o[3])
        assert out_arrays is not None
        return out_arrays
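
# A minimal usage sketch for _Interp2, written as a doctest-style comment.
# The control points and values below are hypothetical, chosen only to show
# the expected shapes; they are not taken from library code:
#
#     >>> import numpy as np
#     >>> vals = np.array([[0.0], [1.0]])  # one row per control point
#     >>> interp = _Interp2([0, 10], vals, interp="linear")
#     >>> out = interp.feed(10)[0]  # shape (1, 10)
#     >>> bool(np.allclose(out[0], np.arange(10) / 10.0))  # ramps 0 -> 1
#     True
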
###############################################################################
# Constant overlap-add processing class


def _check_store(store):
    if isinstance(store, np.ndarray):
        store = [store]
    if isinstance(store, (list, tuple)) and all(
        isinstance(s, np.ndarray) for s in store
    ):
        store = _Storer(*store)
    if not callable(store):
        raise TypeError(f"store must be callable, got type {type(store)}")
    return store


class _COLA:
    r"""Constant overlap-add processing helper.

    Parameters
    ----------
    process : callable
        A function that takes a chunk of input data with shape
        ``(n_channels, n_samples)`` and processes it.
    store : callable | ndarray
        A function that takes a completed chunk of output data.
        Can also be an ``ndarray``, in which case it is treated as the
        output data in which to store the results.
    n_total : int
        The total number of samples.
    n_samples : int
        The number of samples per window.
    n_overlap : int
        The overlap between windows.
    sfreq : float
        The sample rate, used only to log durations in seconds.
    window : str
        The window to use. Default is "hann".
    tol : float
        The tolerance for COLA checking.

    Notes
    -----
    This will process data using overlapping windows to achieve a constant
    output value. For example, for ``n_total=27``, ``n_samples=10``,
    ``n_overlap=5`` and ``window='triang'``::

        1 _____               _______
          |    \   /\   /\   /
          |     \ /  \ /  \ /
          |      x    x    x
          |     / \  / \  / \
          |    /   \/   \/   \
        0 +----|----|----|----|----|-
          0    5   10   15   20   25

    This produces four windows: the first three are the requested length
    (10 samples) and the last one is longer (12 samples). The first and
    last window are asymmetric.
    """

    @verbose
    def __init__(
        self,
        process,
        store,
        n_total,
        n_samples,
        n_overlap,
        sfreq,
        window="hann",
        tol=1e-10,
        *,
        verbose=None,
    ):
        n_samples = _ensure_int(n_samples, "n_samples")
        n_overlap = _ensure_int(n_overlap, "n_overlap")
        n_total = _ensure_int(n_total, "n_total")
        if n_samples <= 0:
            raise ValueError(f"n_samples must be > 0, got {n_samples}")
        if n_overlap < 0:
            raise ValueError(f"n_overlap must be >= 0, got {n_overlap}")
        if n_total < 0:
            raise ValueError(f"n_total must be >= 0, got {n_total}")
        self._n_samples = int(n_samples)
        self._n_overlap = int(n_overlap)
        del n_samples, n_overlap
        if n_total < self._n_samples:
            raise ValueError(
                f"Number of samples per window ({self._n_samples}) must be at "
                f"most the total number of samples ({n_total})"
            )
        if not callable(process):
            raise TypeError(f"process must be callable, got type {type(process)}")
        self._process = process
        self._step = self._n_samples - self._n_overlap
        self._store = _check_store(store)
        self._idx = 0
        self._in_buffers = self._out_buffers = None

        # Create our window boundaries
        window_name = window if isinstance(window, str) else "custom"
        self._window = get_window(
            window, self._n_samples, fftbins=(self._n_samples - 1) % 2
        )
        self._window /= _check_cola(
            self._window, self._n_samples, self._step, window_name, tol=tol
        )
        self.starts = np.arange(0, n_total - self._n_samples + 1, self._step)
        self.stops = self.starts + self._n_samples
        delta = n_total - self.stops[-1]
        self.stops[-1] = n_total
        sfreq = float(sfreq)
        pl = "s" if len(self.starts) != 1 else ""
        logger.info(
            "    Processing %4d data chunk%s of (at least) %0.1f s "
            "with %0.1f s overlap and %s windowing"
            % (
                len(self.starts),
                pl,
                self._n_samples / sfreq,
                self._n_overlap / sfreq,
                window_name,
            )
        )
        del window, window_name
        if delta > 0:
            logger.info(
                f"    The final {delta / sfreq} s will be lumped into the final window"
            )

    @property
    def _in_offset(self):
        """Compute from current processing window start and buffer len."""
        return self.starts[self._idx] + self._in_buffers[0].shape[-1]

    @verbose
    def feed(self, *datas, verbose=None, **kwargs):
        """Pass in a chunk of data."""
        # Append to our input buffer
        if self._in_buffers is None:
            self._in_buffers = [None] * len(datas)
        if len(datas) != len(self._in_buffers):
            raise ValueError(
                f"Got {len(datas)} array(s), needed {len(self._in_buffers)}"
            )
        for di, data in enumerate(datas):
            if not isinstance(data, np.ndarray) or data.ndim < 1:
                raise TypeError(
                    f"data entry {di} must be an ndarray with at least one "
                    f"dimension, got {type(data)}"
                )
            if self._in_buffers[di] is None:
                # In practice, users can give large chunks, so we use
                # dynamic allocation of the in buffer. We could save some
                # memory allocation by only ever processing max_len at once,
                # but this would increase code complexity.
                self._in_buffers[di] = np.empty(data.shape[:-1] + (0,), data.dtype)
            if (
                data.shape[:-1] != self._in_buffers[di].shape[:-1]
                or self._in_buffers[di].dtype != data.dtype
            ):
                raise TypeError(
                    f"data must have dtype {self._in_buffers[di].dtype} and "
                    f"shape[:-1]=={self._in_buffers[di].shape[:-1]}, got dtype "
                    f"{data.dtype} shape[:-1]={data.shape[:-1]}"
                )
            logger.debug(
                f"    + Appending {self._in_offset:d}->"
                f"{self._in_offset + data.shape[-1]:d}"
            )
            self._in_buffers[di] = np.concatenate([self._in_buffers[di], data], -1)
            if self._in_offset > self.stops[-1]:
                raise ValueError(
                    f"data (shape {data.shape}) exceeded expected total buffer size ("
                    f"{self._in_offset} > {self.stops[-1]})"
                )
        # Check to see if we can process the next chunk and dump outputs
        while (
            self._idx < len(self.starts) and self._in_offset >= self.stops[self._idx]
        ):
            start, stop = self.starts[self._idx], self.stops[self._idx]
            this_len = stop - start
            this_window = self._window.copy()
            if self._idx == len(self.starts) - 1:
                this_window = np.pad(
                    self._window, (0, this_len - len(this_window)), "constant"
                )
                for offset in range(self._step, len(this_window), self._step):
                    n_use = len(this_window) - offset
                    this_window[offset:] += self._window[:n_use]
            if self._idx == 0:
                for offset in range(self._n_samples - self._step, 0, -self._step):
                    this_window[:offset] += self._window[-offset:]
            logger.debug(f"    * Processing {start}->{stop}")
            this_proc = [in_[..., :this_len].copy() for in_ in self._in_buffers]
            if not all(
                proc.shape[-1] == this_len == this_window.size for proc in this_proc
            ):
                raise RuntimeError("internal indexing error")
            outs = self._process(*this_proc, **kwargs)
            if self._out_buffers is None:
                max_len = np.max(self.stops - self.starts)
                self._out_buffers = [
                    np.zeros(o.shape[:-1] + (max_len,), o.dtype) for o in outs
                ]
            for oi, out in enumerate(outs):
                out *= this_window
                self._out_buffers[oi][..., : stop - start] += out
            self._idx += 1
            if self._idx < len(self.starts):
                next_start = self.starts[self._idx]
            else:
                next_start = self.stops[-1]
            delta = next_start - self.starts[self._idx - 1]
            for di in range(len(self._in_buffers)):
                self._in_buffers[di] = self._in_buffers[di][..., delta:]
            logger.debug(f"    - Shifting input/output buffers by {delta:d} samples")
            self._store(*[o[..., :delta] for o in self._out_buffers])
            for ob in self._out_buffers:
                ob[..., :-delta] = ob[..., delta:]
                ob[..., -delta:] = 0.0
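
# A quick correctness sketch for _COLA (doctest-style comment). The identity
# "process" function below is hypothetical; because the windows are normalized
# to satisfy COLA, the overlap-added output should reproduce the input:
#
#     >>> import numpy as np
#     >>> x = np.random.default_rng(0).standard_normal((1, 27))
#     >>> out = np.zeros_like(x)
#     >>> cola = _COLA(lambda d: [d], out, n_total=27, n_samples=10,
#     ...              n_overlap=5, sfreq=1.0)
#     >>> cola.feed(x)
#     >>> bool(np.allclose(out, x))
#     True
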
def _check_cola(win, nperseg, step, window_name, tol=1e-10):
    """Check whether the Constant OverLap Add (COLA) constraint is met."""
    # adapted from SciPy
    binsums = np.sum(
        [win[ii * step : (ii + 1) * step] for ii in range(nperseg // step)], axis=0
    )
    if nperseg % step != 0:
        binsums[: nperseg % step] += win[-(nperseg % step) :]
    const = np.median(binsums)
    deviation = np.max(np.abs(binsums - const))
    if deviation > tol:
        raise ValueError(
            f"segment length {nperseg:d} with step {step:d} for {window_name} window "
            "type does not provide a constant output "
            f"({100 * deviation / const:g}% deviation)"
        )
    return const
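
# For orientation (doctest-style comment, hypothetical parameters): a periodic
# Hann window with 50% overlap satisfies COLA, and its overlapping bins sum to
# a constant of 1.0, which _check_cola returns:
#
#     >>> import numpy as np
#     >>> from scipy.signal import get_window
#     >>> win = get_window("hann", 10, fftbins=True)  # periodic Hann
#     >>> bool(np.isclose(_check_cola(win, 10, 5, "hann"), 1.0))
#     True
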
class _Storer:
    """Store data in chunks."""

    def __init__(self, *outs, picks=None):
        for oi, out in enumerate(outs):
            if not isinstance(out, np.ndarray) or out.ndim < 1:
                raise TypeError(f"outs[{oi}] must be >= 1D ndarray, got {out}")
        self.outs = outs
        self.idx = 0
        self.picks = picks

    def __call__(self, *outs):
        if len(outs) != len(self.outs) or not all(
            out.shape[-1] == outs[0].shape[-1] for out in outs
        ):
            raise ValueError("Bad outs")
        idx = (Ellipsis,)
        if self.picks is not None:
            idx += (self.picks,)
        stop = self.idx + outs[0].shape[-1]
        idx += (slice(self.idx, stop),)
        for o1, o2 in zip(self.outs, outs):
            o1[idx] = o2
        self.idx = stop
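
# A minimal usage sketch for _Storer (doctest-style comment; the shapes below
# are hypothetical). Successive calls write successive chunks along the last
# axis of the preallocated output:
#
#     >>> import numpy as np
#     >>> out = np.zeros((2, 6))
#     >>> store = _Storer(out)
#     >>> store(np.ones((2, 4)))
#     >>> store(2 * np.ones((2, 2)))
#     >>> out[0]
#     array([1., 1., 1., 1., 2., 2.])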