99 lines
3.0 KiB
Python
99 lines
3.0 KiB
Python
"""Bad channel detection using Local Outlier Factor (LOF)."""
|
|
|
|
# Authors: The MNE-Python contributors.
|
|
# License: BSD-3-Clause
|
|
# Copyright the MNE-Python contributors.
|
|
|
|
import numpy as np
|
|
|
|
from .._fiff.pick import _picks_to_idx
|
|
from ..io.base import BaseRaw
|
|
from ..utils import _soft_import, _validate_type, logger, verbose
|
|
|
|
|
|
@verbose
def find_bad_channels_lof(
    raw,
    n_neighbors=20,
    *,
    picks=None,
    metric="euclidean",
    threshold=1.5,
    return_scores=False,
    verbose=None,
):
    """Find bad channels using Local Outlier Factor (LOF) algorithm.

    Parameters
    ----------
    raw : instance of Raw
        Raw data to process.
    n_neighbors : int
        Number of neighbors defining the local neighborhood (default is 20).
        Smaller values will lead to higher LOF scores.
    %(picks_good_data)s
    metric : str
        Metric to use for distance computation. Default is ``"euclidean"``,
        see :func:`sklearn.metrics.pairwise.distance_metrics` for details.
    threshold : float
        Threshold to define outliers. Theoretical threshold ranges anywhere
        between 1.0 and any positive integer. Default: 1.5
        It is recommended to consider this as a hyperparameter to optimize.
    return_scores : bool
        If ``True``, also return an array with the LOF score of each
        evaluated channel (see ``scores`` below). Default is ``False``.
    %(verbose)s

    Returns
    -------
    noisy_chs : list
        List of bad M/EEG channels that were automatically detected.
    scores : ndarray, shape (n_picks,)
        Only returned when ``return_scores`` is ``True``. It contains the
        LOF outlier score for each channel in ``picks``, as stored in
        scikit-learn's ``negative_outlier_factor_`` attribute (i.e., with a
        negative sign). A channel is reported as bad when the absolute value
        of its score is greater than or equal to ``threshold``.

    Raises
    ------
    ValueError
        If the picked channels do not all share exactly one channel type.

    See Also
    --------
    maxwell_filter
    annotate_amplitude

    Notes
    -----
    See :footcite:`KumaravelEtAl2022` and :footcite:`BreunigEtAl2000` for background on
    choosing ``threshold``.

    .. versionadded:: 1.7

    References
    ----------
    .. footbibliography::
    """  # noqa: E501
    # scikit-learn is an optional dependency, so it is imported lazily here.
    _soft_import("sklearn", "using LOF detection", strict=True)
    from sklearn.neighbors import LocalOutlierFactor

    _validate_type(raw, BaseRaw, "raw")
    # Get the channel types
    channel_types = raw.get_channel_types()
    picks = _picks_to_idx(raw.info, picks=picks, none="data", exclude="bads")
    picked_ch_types = {channel_types[pick] for pick in picks}

    # Only a single channel type is supported per call.
    if len(picked_ch_types) != 1:
        raise ValueError(
            f"Need exactly one channel type in picks, got {sorted(picked_ch_types)}"
        )

    ch_names = [raw.ch_names[pick] for pick in picks]
    # Each picked channel (one row of ``data``) is one sample for the estimator.
    data = raw.get_data(picks=picks)
    clf = LocalOutlierFactor(n_neighbors=n_neighbors, metric=metric)
    # Only the fitted scores are needed; the inlier/outlier labels that
    # ``fit_predict`` would additionally return are not used.
    clf.fit(data)
    scores_lof = clf.negative_outlier_factor_

    # The scores carry a negative sign, so compare their magnitude to the
    # threshold. Row ``i`` of ``data`` corresponds to ``ch_names[i]``.
    outlier_idx = np.flatnonzero(np.abs(scores_lof) >= threshold)
    bads = [ch_names[idx] for idx in outlier_idx]
    logger.info(f"LOF: Detected bad channel(s): {bads}")

    if return_scores:
        return bads, scores_lof
    return bads
|