"""Utility functions to baseline-correct data."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np

from .utils import _check_option, _validate_type, logger, verbose

# Allowed values of the ``mode`` parameter of :func:`rescale`. Shared by
# ``_log_rescale`` and ``rescale`` so both validate against the same list.
_RESCALE_MODES = ("logratio", "ratio", "zscore", "mean", "percent", "zlogratio")


def _log_rescale(baseline, mode="mean"):
    """Build the log message describing the rescaling method.

    Parameters
    ----------
    baseline : tuple | None
        The baseline period. Only compared against ``None`` here.
    mode : str
        The rescaling mode. It is validated only when ``baseline`` is not
        ``None``.

    Returns
    -------
    msg : str
        Human-readable description of the baseline correction to apply.
    """
    if baseline is None:
        return "No baseline correction applied"
    _check_option("mode", mode, _RESCALE_MODES)
    return f"Applying baseline correction (mode: {mode})"


@verbose
def rescale(data, times, baseline, mode="mean", copy=True, picks=None, verbose=None):
    """Rescale (baseline correct) data.

    Parameters
    ----------
    data : array
        It can be of any shape. The only constraint is that the last
        dimension should be time.
    times : 1D array
        Time instants in seconds.
    %(baseline_rescale)s
    mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio'
        Perform baseline correction by

        - subtracting the mean of baseline values ('mean')
        - dividing by the mean of baseline values ('ratio')
        - dividing by the mean of baseline values and taking the log
          ('logratio')
        - subtracting the mean of baseline values followed by dividing by
          the mean of baseline values ('percent')
        - subtracting the mean of baseline values and dividing by the
          standard deviation of baseline values ('zscore')
        - dividing by the mean of baseline values, taking the log, and
          dividing by the standard deviation of log baseline values
          ('zlogratio')

    copy : bool
        Whether to return a new instance or modify in place.
    picks : list of int | None
        Data to process along the axis=-2 (None, default, processes all).
    %(verbose)s

    Returns
    -------
    data_scaled : array
        Array of same shape as data after rescaling.

    Raises
    ------
    ValueError
        If ``mode`` is invalid (when ``baseline`` is not ``None``), or if the
        baseline limits do not select at least one sample of ``times``.
    """
    if copy:
        data = data.copy()
    if baseline is not None:
        # Validate here rather than relying on ``_log_rescale``: that call is
        # skipped when ``verbose=False``. An invalid mode would then reach the
        # ``fun`` dispatch below without defining ``fun`` and fail with an
        # UnboundLocalError instead of a clear ValueError.
        _check_option("mode", mode, _RESCALE_MODES)
    if verbose is not False:
        logger.info(_log_rescale(baseline, mode))
    if baseline is None or data.shape[-1] == 0:
        return data

    bmin, bmax = baseline
    if bmin is None:
        imin = 0
    else:
        imin = np.where(times >= bmin)[0]
        if len(imin) == 0:
            raise ValueError(
                f"bmin is too large ({bmin}), it exceeds the largest time value"
            )
        imin = int(imin[0])
    if bmax is None:
        imax = len(times)
    else:
        imax = np.where(times <= bmax)[0]
        if len(imax) == 0:
            raise ValueError(
                f"bmax is too small ({bmax}), it is smaller than the smallest time "
                "value"
            )
        # +1 so the sample at ``bmax`` itself is included in the slice.
        imax = int(imax[-1]) + 1
    if imin >= imax:
        raise ValueError(
            f"Bad rescaling slice ({imin}:{imax}) from time values {bmin}, {bmax}"
        )

    # technically this is inefficient when `picks` is given, but assuming
    # that we generally pick most channels for rescaling, it's not so bad
    # ``keepdims`` leaves a trailing length-1 time axis so ``mean`` broadcasts
    # against ``data``.
    mean = np.mean(data[..., imin:imax], axis=-1, keepdims=True)

    # Each ``fun`` modifies ``d`` in place; ``m`` is the baseline mean.
    if mode == "mean":

        def fun(d, m):
            d -= m

    elif mode == "ratio":

        def fun(d, m):
            d /= m

    elif mode == "logratio":

        def fun(d, m):
            d /= m
            np.log10(d, out=d)

    elif mode == "percent":

        def fun(d, m):
            d -= m
            d /= m

    elif mode == "zscore":

        def fun(d, m):
            d -= m
            d /= np.std(d[..., imin:imax], axis=-1, keepdims=True)

    elif mode == "zlogratio":

        def fun(d, m):
            d /= m
            np.log10(d, out=d)
            d /= np.std(d[..., imin:imax], axis=-1, keepdims=True)

    if picks is None:
        fun(data, mean)
    else:
        # ``data[..., pi, :]`` is a view, so the in-place ops update ``data``.
        for pi in picks:
            fun(data[..., pi, :], mean[..., pi, :])
    return data


def _check_baseline(baseline, times, sfreq, on_baseline_outside_data="raise"):
    """Check if the baseline is valid and adjust it if requested.

    ``None`` values inside ``baseline`` will be replaced with ``times[0]`` and
    ``times[-1]``.

    Parameters
    ----------
    baseline : array-like, shape (2,) | None
        Beginning and end of the baseline period, in seconds. If ``None``,
        assume no baseline and return immediately.
    times : array
        The time points.
    sfreq : float
        The sampling rate.
    on_baseline_outside_data : 'raise' | 'info' | 'adjust'
        What to do if the baseline period exceeds the data.
        If ``'raise'``, raise an exception (default).
        If ``'info'``, log an info message.
        If ``'adjust'``, adjust the baseline such that it is within the data
        range.

    Returns
    -------
    (baseline_tmin, baseline_tmax) | None
        The baseline with ``None`` values replaced with times, and with
        adjusted times if ``on_baseline_outside_data='adjust'``; or ``None``,
        if ``baseline`` is ``None``.

    Raises
    ------
    ValueError
        If ``baseline`` does not have exactly two elements, if it selects only
        one sample via the ``(None, 0)`` default with ``times[0] == 0``, if its
        start is after its end, if ``on_baseline_outside_data`` is not one of
        the allowed values, or if it lies outside the data and
        ``on_baseline_outside_data='raise'``.
    """
    if baseline is None:
        return None

    _validate_type(baseline, "array-like")
    baseline = tuple(baseline)

    if len(baseline) != 2:
        raise ValueError(
            f"baseline must have exactly two elements (got {len(baseline)})."
        )

    # Without this check an unknown value (e.g. a typo) would silently skip
    # every branch of the outside-data handling below.
    _check_option(
        "on_baseline_outside_data",
        on_baseline_outside_data,
        ["raise", "info", "adjust"],
    )

    tmin, tmax = times[0], times[-1]
    # One-sample tolerance used when deciding if the baseline exceeds the data.
    tstep = 1.0 / float(sfreq)

    # check default value of baseline and `tmin=0`
    if baseline == (None, 0) and tmin == 0:
        raise ValueError(
            "Baseline interval is only one sample. Use `baseline=(0, 0)` if this is "
            "desired."
        )

    baseline_tmin, baseline_tmax = baseline

    if baseline_tmin is None:
        baseline_tmin = tmin
    baseline_tmin = float(baseline_tmin)

    if baseline_tmax is None:
        baseline_tmax = tmax
    baseline_tmax = float(baseline_tmax)

    if baseline_tmin > baseline_tmax:
        raise ValueError(
            f"Baseline min ({baseline_tmin}) must be less than baseline max ("
            f"{baseline_tmax})"
        )

    if (baseline_tmin < tmin - tstep) or (baseline_tmax > tmax + tstep):
        msg = (
            f"Baseline interval [{baseline_tmin}, {baseline_tmax}] s is outside of "
            f"epochs data [{tmin}, {tmax}] s. Epochs were probably cropped."
        )
        if on_baseline_outside_data == "raise":
            raise ValueError(msg)
        elif on_baseline_outside_data == "info":
            logger.info(msg)
        elif on_baseline_outside_data == "adjust":
            if baseline_tmin < tmin - tstep:
                baseline_tmin = tmin
            if baseline_tmax > tmax + tstep:
                baseline_tmax = tmax

    return baseline_tmin, baseline_tmax