# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np

from .._fiff.pick import _picks_by_type, _picks_to_idx
from ..annotations import (
    Annotations,
    _adjust_onset_meas_date,
    _annotations_starts_stops,
)
from ..fixes import jit
from ..io import BaseRaw
from ..utils import _mask_to_onsets_offsets, _validate_type, logger, verbose


@verbose
def annotate_amplitude(
    raw,
    peak=None,
    flat=None,
    bad_percent=5,
    min_duration=0.005,
    picks=None,
    *,
    verbose=None,
):
    """Annotate raw data based on peak-to-peak amplitude.

    Creates annotations ``BAD_peak`` or ``BAD_flat`` for spans of data where
    consecutive samples exceed the threshold in ``peak`` or fall below the
    threshold in ``flat`` for more than ``min_duration``.
    Channels where more than ``bad_percent`` of the total recording length
    should be annotated with either ``BAD_peak`` or ``BAD_flat`` are returned
    in ``bads`` instead.
    Note that the annotations and the bads are not automatically added to the
    :class:`~mne.io.Raw` object; use :meth:`~mne.io.Raw.set_annotations` and
    :class:`info['bads'] <mne.Info>` to do so.

    Parameters
    ----------
    raw : instance of Raw
        The raw data.
    peak : float | dict | None
        Annotate segments based on **maximum** peak-to-peak signal amplitude
        (PTP). Valid **keys** can be any channel type present in the object.
        The **values** are floats that set the maximum acceptable PTP. If the
        PTP is larger than this threshold, the segment will be annotated.
        If float, the maximum acceptable PTP is applied to all channels.
    flat : float | dict | None
        Annotate segments based on **minimum** peak-to-peak signal amplitude
        (PTP). Valid **keys** can be any channel type present in the object.
        The **values** are floats that set the minimum acceptable PTP. If the
        PTP is smaller than this threshold, the segment will be annotated.
        If float, the minimum acceptable PTP is applied to all channels.
    bad_percent : float
        The percentage of the time a channel can be above or below thresholds.
        Below this percentage, :class:`~mne.Annotations` are created.
        Above this percentage, the channel involved is returned in ``bads``.
        Note the returned ``bads`` are not automatically added to
        :class:`info['bads'] <mne.Info>`.
        Defaults to ``5``, i.e. 5%%.
    min_duration : float
        The minimum duration (s) required for consecutive samples to be above
        ``peak`` or below ``flat`` thresholds to be considered as above or
        below threshold.
        For some systems, adjacent time samples with exactly the same value
        are not uncommon. Defaults to ``0.005`` (5 ms).
    %(picks_good_data)s
    %(verbose)s

    Returns
    -------
    annotations : instance of Annotations
        The annotated bad segments.
    bads : list
        The channels detected as bad.

    Notes
    -----
    This function does not use a window to detect small peak-to-peak or large
    peak-to-peak amplitude changes as the ``reject`` and ``flat`` arguments of
    :class:`~mne.Epochs` do. Instead, it looks at the difference between
    consecutive samples.

    - When used to detect segments below ``flat``, at least ``min_duration``
      seconds of consecutive samples must respect
      ``abs(a[i+1] - a[i]) ≤ flat``.
    - When used to detect segments above ``peak``, at least ``min_duration``
      seconds of consecutive samples must respect
      ``abs(a[i+1] - a[i]) ≥ peak``.

    Thus, this function does not detect every temporal event with large
    peak-to-peak amplitude, but only the ones where the peak-to-peak amplitude
    is supra-threshold between consecutive samples. For instance, segments
    experiencing a DC shift will not be picked up. Only the edges of the DC
    shift will be annotated (and those only if the edge transitions are longer
    than ``min_duration``).

    This function may perform faster if the data is loaded in memory, as it
    loads the data one channel type at a time (across all time points), which
    is typically not an efficient way to read raw data from disk.

    .. versionadded:: 1.0
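
    Examples
    --------
    A minimal usage sketch; the file name and threshold values below are
    illustrative only and depend on the channel types and units of your data:

    >>> import mne  # doctest: +SKIP
    >>> raw = mne.io.read_raw_fif("sample_raw.fif")  # doctest: +SKIP
    >>> annotations, bads = annotate_amplitude(
    ...     raw,
    ...     peak=dict(eeg=100e-6),  # annotate PTP jumps larger than 100 µV
    ...     flat=dict(eeg=1e-6),  # annotate PTP jumps smaller than 1 µV
    ...     bad_percent=5,
    ... )  # doctest: +SKIP
    >>> raw.set_annotations(raw.annotations + annotations)  # doctest: +SKIP
    >>> raw.info["bads"].extend(bads)  # doctest: +SKIP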
    """
    _validate_type(raw, BaseRaw, "raw")
    picks_ = _picks_to_idx(raw.info, picks, "data_or_ica", exclude="bads")
    peak = _check_ptp(peak, "peak", raw.info, picks_)
    flat = _check_ptp(flat, "flat", raw.info, picks_)
    if peak is None and flat is None:
        raise ValueError(
            "At least one of the arguments 'peak' or 'flat' must not be None."
        )
    bad_percent = _check_bad_percent(bad_percent)
    min_duration = _check_min_duration(
        min_duration, raw.times.size * 1 / raw.info["sfreq"]
    )
    min_duration_samples = int(np.round(min_duration * raw.info["sfreq"]))
    bads = list()

    # grouping picks by channel types to avoid operating on each channel
    # individually
    picks = {
        ch_type: np.intersect1d(picks_of_type, picks_, assume_unique=True)
        for ch_type, picks_of_type in _picks_by_type(raw.info, exclude="bads")
        if np.intersect1d(picks_of_type, picks_, assume_unique=True).size != 0
    }
    del picks_  # re-using this variable name in for loop

    # skip BAD_acq_skip sections
    onsets, ends = _annotations_starts_stops(raw, "bad_acq_skip", invert=True)
    index = np.concatenate(
        [np.arange(raw.times.size)[onset:end] for onset, end in zip(onsets, ends)]
    )

    # size matching the diff a[i+1] - a[i]
    any_flat = np.zeros(len(raw.times) - 1, bool)
    any_peak = np.zeros(len(raw.times) - 1, bool)

    # look for discrete difference above or below thresholds
    logger.info("Finding segments below or above PTP threshold.")
    for ch_type, picks_ in picks.items():
        data = np.concatenate(
            [raw[picks_, onset:end][0] for onset, end in zip(onsets, ends)], axis=1
        )
        diff = np.abs(np.diff(data, axis=1))

        if flat is not None:
            flat_ = diff <= flat[ch_type]
            # reject too short segments
            flat_ = _reject_short_segments(flat_, min_duration_samples)
            # reject channels above maximum bad_percentage
            flat_count = flat_.sum(axis=1)
            flat_count[np.nonzero(flat_count)] += 1  # offset by 1 due to diff
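            # (a run of n consecutive flagged differences covers n + 1
            # samples, hence the offset by one above before converting the
            # count to a percentage of samples)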
            flat_mean = flat_count / raw.times.size * 100
            flat_ch_to_set_bad = picks_[np.where(flat_mean >= bad_percent)[0]]
            bads.extend(flat_ch_to_set_bad)
            # add onset/offset for annotations
            flat_ch_to_annotate = np.where((0 < flat_mean) & (flat_mean < bad_percent))[
                0
            ]
            # convert from raw.times[onset:end] - 1 to raw.times[:] - 1
            idx = index[np.where(flat_[flat_ch_to_annotate, :])[1]]
            any_flat[idx] = True

        if peak is not None:
            peak_ = diff >= peak[ch_type]
            # reject too short segments
            peak_ = _reject_short_segments(peak_, min_duration_samples)
            # reject channels above maximum bad_percentage
            peak_count = peak_.sum(axis=1)
            peak_count[np.nonzero(peak_count)] += 1  # offset by 1 due to diff
            peak_mean = peak_count / raw.times.size * 100
            peak_ch_to_set_bad = picks_[np.where(peak_mean >= bad_percent)[0]]
            bads.extend(peak_ch_to_set_bad)
            # add onset/offset for annotations
            peak_ch_to_annotate = np.where((0 < peak_mean) & (peak_mean < bad_percent))[
                0
            ]
            # convert from raw.times[onset:end] - 1 to raw.times[:] - 1
            idx = index[np.where(peak_[peak_ch_to_annotate, :])[1]]
            any_peak[idx] = True

    # annotation for flat
    annotation_flat = _create_annotations(any_flat, "flat", raw)
    # annotation for peak
    annotation_peak = _create_annotations(any_peak, "peak", raw)
    # group
    annotations = annotation_flat + annotation_peak
    # bads
    bads = [raw.ch_names[bad] for bad in bads if bad not in raw.info["bads"]]

    return annotations, bads

def _check_ptp(ptp, name, info, picks):
    """Check the PTP threshold argument, and convert it to dict if needed."""
    _validate_type(ptp, ("numeric", dict, None))

    if ptp is not None and not isinstance(ptp, dict):
        if ptp < 0:
            raise ValueError(
                f"Argument '{name}' should define a positive threshold. "
                f"Provided: '{ptp}'."
            )
        ch_types = set(info.get_channel_types(picks))
        ptp = {ch_type: ptp for ch_type in ch_types}
    elif isinstance(ptp, dict):
        for key, value in ptp.items():
            if value < 0:
                raise ValueError(
                    f"Argument '{name}' should define positive thresholds. "
                    f"Provided for channel type '{key}': '{value}'."
                )
    return ptp

def _check_bad_percent(bad_percent):
    """Check that bad_percent is a valid percentage and convert it to float."""
    _validate_type(bad_percent, "numeric", "bad_percent")
    bad_percent = float(bad_percent)
    if not 0 <= bad_percent <= 100:
        raise ValueError(
            "Argument 'bad_percent' should define a percentage between 0% "
            f"and 100%. Provided: {bad_percent}%."
        )
    return bad_percent

def _check_min_duration(min_duration, raw_duration):
    """Check that min_duration is a valid duration and convert it to float."""
    _validate_type(min_duration, "numeric", "min_duration")
    min_duration = float(min_duration)
    if min_duration < 0:
        raise ValueError(
            "Argument 'min_duration' should define a positive duration in "
            f"seconds. Provided: '{min_duration}' seconds."
        )
    if min_duration >= raw_duration:
        raise ValueError(
            "Argument 'min_duration' should define a positive duration in "
            f"seconds shorter than the raw duration ({raw_duration} seconds). "
            f"Provided: '{min_duration}' seconds."
        )
    return min_duration

def _reject_short_segments(arr, min_duration_samples):
    """Check if flat or peak segments are longer than the minimum duration."""
    assert arr.dtype == np.dtype(bool) and arr.ndim == 2
    for k, ch in enumerate(arr):
        onsets, offsets = _mask_to_onsets_offsets(ch)
        _mark_inner(arr[k], onsets, offsets, min_duration_samples)
    return arr

@jit()
def _mark_inner(arr_k, onsets, offsets, min_duration_samples):
    """Inner loop of _reject_short_segments()."""
    for start, stop in zip(onsets, offsets):
        if stop - start < min_duration_samples:
            arr_k[start:stop] = False

def _create_annotations(any_arr, kind, raw):
    """Create the peak or flat annotations from the any_arr."""
    assert kind in ("peak", "flat")
    starts, stops = _mask_to_onsets_offsets(any_arr)
    starts, stops = np.array(starts), np.array(stops)
    onsets = starts / raw.info["sfreq"]
    durations = (stops - starts) / raw.info["sfreq"]
    annot = Annotations(
        onsets,
        durations,
        [f"BAD_{kind}"] * len(onsets),
        orig_time=raw.info["meas_date"],
    )
    _adjust_onset_meas_date(annot, raw)
    return annot