# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np

from ..._fiff.constants import FIFF
from ...annotations import _annotations_starts_stops
from ...io import BaseRaw
from ...utils import _check_preload, _validate_type, logger, warn


def interpolate_blinks(raw, buffer=0.05, match="BAD_blink", interpolate_gaze=False):
    """Interpolate eyetracking signals during blinks.

    This function uses the timing of blink annotations to estimate missing
    data. Missing values are then interpolated linearly. Operates in place.

    Parameters
    ----------
    raw : instance of Raw
        The raw data with at least one ``'pupil'`` or ``'eyegaze'`` channel.
    buffer : float | array-like of float, shape ``(2,)``
        The time in seconds before and after a blink to consider invalid and
        include in the segment to be interpolated over. Default is ``0.05`` seconds
        (50 ms). If array-like, the first element is the time before the blink and the
        second element is the time after the blink to consider invalid, for example,
        ``(0.025, .1)``.
    match : str | list of str
        The description of annotations to interpolate over. If a list, the data within
        all annotations that match any of the strings in the list will be interpolated
        over. If a ``match`` starts with ``'BAD_'``, that part will be removed from the
        annotation description after interpolation. Defaults to ``'BAD_blink'``.
    interpolate_gaze : bool
        If False, only apply interpolation to ``'pupil'`` channels. If True,
        interpolate over ``'eyegaze'`` channels as well. Defaults to False, because
        eye position can change in unpredictable ways during blinks.

    Returns
    -------
    self : instance of Raw
        Returns the modified instance.

    Raises
    ------
    ValueError
        If ``buffer`` does not contain exactly one or two values, or if a
        matching annotation has no ``ch_names``.

    Notes
    -----
    .. versionadded:: 1.5
    """
    _check_preload(raw, "interpolate_blinks")
    _validate_type(raw, BaseRaw, "raw")
    _validate_type(buffer, (float, tuple, list, np.ndarray), "buffer")
    _validate_type(match, (str, tuple, list, np.ndarray), "match")

    # determine the buffer around blinks to include in the interpolation;
    # flatten so that a scalar, a 0-d array and a 1-element sequence all end
    # up as the same ``(pre, post)`` pair
    buffer = np.array(buffer, dtype=float).ravel()
    if buffer.size == 1:
        buffer = np.repeat(buffer, 2)
    elif buffer.size != 2:
        raise ValueError(
            "buffer must be a single float or an array-like of two floats "
            f"(before, after), got {buffer.size} values."
        )

    if isinstance(match, str):
        match = [match]

    # get the blink annotations
    blink_annots = [annot for annot in raw.annotations if annot["description"] in match]
    if not blink_annots:
        warn(f"No annotations matching {match} found. Aborting.")
        return raw
    _interpolate_blinks(raw, buffer, blink_annots, interpolate_gaze=interpolate_gaze)

    # remove bad from the annotation description
    prefix = "BAD_"
    for desc in match:
        if not desc.startswith(prefix):
            continue
        # ``match`` may list a description that has no annotations, or repeat
        # one that an earlier iteration already renamed; only rename labels
        # that currently exist so ``rename`` is never given a missing key.
        if desc not in set(raw.annotations.description):
            continue
        logger.info(f"Removing 'BAD_' from {desc}.")
        # strip only the leading prefix, not later occurrences of "BAD_"
        raw.annotations.rename({desc: desc[len(prefix) :]})
    return raw


def _interpolate_blinks(raw, buffer, blink_annots, interpolate_gaze):
    """Interpolate eyetracking signals during blinks in-place.

    Parameters
    ----------
    raw : instance of Raw
        The preloaded raw data; ``raw._data`` is modified in place.
    buffer : array-like of float, shape ``(2,)``
        Seconds to add before and after each blink, respectively.
    blink_annots : list of dict
        The annotations (as yielded by iterating ``raw.annotations``) that
        mark the blink periods. Each must have a non-empty ``'ch_names'``.
    interpolate_gaze : bool
        If True, interpolate every eyetrack channel; otherwise only channels
        whose coil type is pupil.

    Raises
    ------
    ValueError
        If a blink annotation has no ``'ch_names'`` (only checked once an
        eligible channel is found).
    """
    logger.info("Interpolating missing data during blinks...")
    pre_buffer, post_buffer = buffer
    sfreq = raw.info["sfreq"]
    times = raw.times

    # Build each blink window from its own annotation so the timing can never
    # be paired with the wrong annotation, whatever its description is.
    # Onsets stored on ``raw.annotations`` are offset by ``raw.first_time``;
    # remove that offset to land on the ``raw.times`` axis and snap to the
    # nearest sample before widening the window by the buffer.
    windows = []
    for annot in blink_annots:
        onset = annot["onset"] - raw.first_time
        offset = onset + annot["duration"]
        start = np.round(onset * sfreq) / sfreq - pre_buffer
        end = np.round(offset * sfreq) / sfreq + post_buffer
        windows.append((annot, start, end))

    # iterate over each eyetrack channel and interpolate the blinks
    interpolated_chs = []
    for ci, ch_info in enumerate(raw.info["chs"]):
        if interpolate_gaze:  # interpolate over all eyetrack channels
            if ch_info["kind"] != FIFF.FIFFV_EYETRACK_CH:
                continue
        else:  # interpolate over pupil channels only
            if ch_info["coil_type"] != FIFF.FIFFV_COIL_EYETRACK_PUPIL:
                continue

        # Boolean mask of the samples of this channel that fall within a blink
        mask = np.zeros_like(times, dtype=bool)
        for annot, start, end in windows:
            if "ch_names" not in annot or not annot["ch_names"]:
                msg = f"Blink annotation missing values for 'ch_names' key: {annot}"
                raise ValueError(msg)
            if ch_info["ch_name"] not in annot["ch_names"]:
                continue  # skip if the channel is not in the blink annotation
            # Update the mask for times within the current blink period
            mask |= (times >= start) & (times <= end)

        blink_indices = np.flatnonzero(mask)
        non_blink_indices = np.flatnonzero(~mask)

        # Linear interpolation from the surrounding valid samples
        interpolated_samples = np.interp(
            times[blink_indices],
            times[non_blink_indices],
            raw._data[ci, non_blink_indices],
        )
        # Replace the samples at the blink_indices with the interpolated values
        raw._data[ci, blink_indices] = interpolated_samples
        interpolated_chs.append(ch_info["ch_name"])

    if interpolated_chs:
        logger.info(
            f"Interpolated {len(interpolated_chs)} channels: {interpolated_chs}"
        )
    else:
        warn("No channels were interpolated.")