222 lines
7.3 KiB
Python
222 lines
7.3 KiB
Python
# Authors: The MNE-Python contributors.
|
|
# License: BSD-3-Clause
|
|
# Copyright the MNE-Python contributors.
|
|
|
|
from collections import Counter
|
|
|
|
import numpy as np
|
|
|
|
from .._fiff.pick import _picks_to_idx, pick_info, pick_types
|
|
from ..parallel import parallel_func
|
|
from ..utils import logger, verbose
|
|
from .base import _set_cv
|
|
from .mixin import EstimatorMixin, TransformerMixin
|
|
|
|
|
|
class EMS(TransformerMixin, EstimatorMixin):
|
|
"""Transformer to compute event-matched spatial filters.
|
|
|
|
This version of EMS :footcite:`SchurgerEtAl2013` operates on the entire
|
|
time course. No time
|
|
window needs to be specified. The result is a spatial filter at each
|
|
time point and a corresponding time course. Intuitively, the result
|
|
gives the similarity between the filter at each time point and the
|
|
data vector (sensors) at that time point.
|
|
|
|
.. note:: EMS only works for binary classification.
|
|
|
|
Attributes
|
|
----------
|
|
filters_ : ndarray, shape (n_channels, n_times)
|
|
The set of spatial filters.
|
|
classes_ : ndarray, shape (n_classes,)
|
|
The target classes.
|
|
|
|
References
|
|
----------
|
|
.. footbibliography::
|
|
"""
|
|
|
|
def __repr__(self): # noqa: D105
|
|
if hasattr(self, "filters_"):
|
|
return (
|
|
f"<EMS: fitted with {len(self.filters_)} filters "
|
|
f"on {len(self.classes_)} classes.>"
|
|
)
|
|
else:
|
|
return "<EMS: not fitted.>"
|
|
|
|
def fit(self, X, y):
|
|
"""Fit the spatial filters.
|
|
|
|
.. note : EMS is fitted on data normalized by channel type before the
|
|
fitting of the spatial filters.
|
|
|
|
Parameters
|
|
----------
|
|
X : array, shape (n_epochs, n_channels, n_times)
|
|
The training data.
|
|
y : array of int, shape (n_epochs)
|
|
The target classes.
|
|
|
|
Returns
|
|
-------
|
|
self : instance of EMS
|
|
Returns self.
|
|
"""
|
|
classes = np.unique(y)
|
|
if len(classes) != 2:
|
|
raise ValueError("EMS only works for binary classification.")
|
|
self.classes_ = classes
|
|
filters = X[y == classes[0]].mean(0) - X[y == classes[1]].mean(0)
|
|
filters /= np.linalg.norm(filters, axis=0)[None, :]
|
|
self.filters_ = filters
|
|
return self
|
|
|
|
def transform(self, X):
|
|
"""Transform the data by the spatial filters.
|
|
|
|
Parameters
|
|
----------
|
|
X : array, shape (n_epochs, n_channels, n_times)
|
|
The input data.
|
|
|
|
Returns
|
|
-------
|
|
X : array, shape (n_epochs, n_times)
|
|
The input data transformed by the spatial filters.
|
|
"""
|
|
Xt = np.sum(X * self.filters_, axis=1)
|
|
return Xt
|
|
|
|
|
|
@verbose
|
|
def compute_ems(
|
|
epochs, conditions=None, picks=None, n_jobs=None, cv=None, verbose=None
|
|
):
|
|
"""Compute event-matched spatial filter on epochs.
|
|
|
|
This version of EMS :footcite:`SchurgerEtAl2013` operates on the entire
|
|
time course. No time
|
|
window needs to be specified. The result is a spatial filter at each
|
|
time point and a corresponding time course. Intuitively, the result
|
|
gives the similarity between the filter at each time point and the
|
|
data vector (sensors) at that time point.
|
|
|
|
.. note : EMS only works for binary classification.
|
|
|
|
.. note : The present function applies a leave-one-out cross-validation,
|
|
following Schurger et al's paper. However, we recommend using
|
|
a stratified k-fold cross-validation. Indeed, leave-one-out tends
|
|
to overfit and cannot be used to estimate the variance of the
|
|
prediction within a given fold.
|
|
|
|
.. note : Because of the leave-one-out, this function needs an equal
|
|
number of epochs in each of the two conditions.
|
|
|
|
Parameters
|
|
----------
|
|
epochs : instance of mne.Epochs
|
|
The epochs.
|
|
conditions : list of str | None, default None
|
|
If a list of strings, strings must match the epochs.event_id's key as
|
|
well as the number of conditions supported by the objective_function.
|
|
If None keys in epochs.event_id are used.
|
|
%(picks_good_data)s
|
|
%(n_jobs)s
|
|
cv : cross-validation object | str | None, default LeaveOneOut
|
|
The cross-validation scheme.
|
|
%(verbose)s
|
|
|
|
Returns
|
|
-------
|
|
surrogate_trials : ndarray, shape (n_trials // 2, n_times)
|
|
The trial surrogates.
|
|
mean_spatial_filter : ndarray, shape (n_channels, n_times)
|
|
The set of spatial filters.
|
|
conditions : ndarray, shape (n_classes,)
|
|
The conditions used. Values correspond to original event ids.
|
|
|
|
References
|
|
----------
|
|
.. footbibliography::
|
|
"""
|
|
logger.info("...computing surrogate time series. This can take some time")
|
|
|
|
# Default to leave-one-out cv
|
|
cv = "LeaveOneOut" if cv is None else cv
|
|
picks = _picks_to_idx(epochs.info, picks)
|
|
|
|
if not len(set(Counter(epochs.events[:, 2]).values())) == 1:
|
|
raise ValueError(
|
|
"The same number of epochs is required by "
|
|
"this function. Please consider "
|
|
"`epochs.equalize_event_counts`"
|
|
)
|
|
|
|
if conditions is None:
|
|
conditions = epochs.event_id.keys()
|
|
epochs = epochs.copy()
|
|
else:
|
|
epochs = epochs[conditions]
|
|
|
|
epochs.drop_bad()
|
|
|
|
if len(conditions) != 2:
|
|
raise ValueError(
|
|
"Currently this function expects exactly 2 "
|
|
f"conditions but you gave me {len(conditions)}"
|
|
)
|
|
|
|
ev = epochs.events[:, 2]
|
|
# Special care to avoid path dependent mappings and orders
|
|
conditions = list(sorted(conditions))
|
|
cond_idx = [np.where(ev == epochs.event_id[k])[0] for k in conditions]
|
|
|
|
info = pick_info(epochs.info, picks)
|
|
data = epochs.get_data(picks=picks)
|
|
|
|
# Scale (z-score) the data by channel type
|
|
# XXX the z-scoring is applied outside the CV, which is not standard.
|
|
for ch_type in ["mag", "grad", "eeg"]:
|
|
if ch_type in epochs:
|
|
# FIXME should be applied to all sort of data channels
|
|
if ch_type == "eeg":
|
|
this_picks = pick_types(info, meg=False, eeg=True)
|
|
else:
|
|
this_picks = pick_types(info, meg=ch_type, eeg=False)
|
|
data[:, this_picks] /= np.std(data[:, this_picks])
|
|
|
|
# Setup cross-validation. Need to use _set_cv to deal with sklearn
|
|
# deprecation of cv objects.
|
|
y = epochs.events[:, 2]
|
|
_, cv_splits = _set_cv(cv, "classifier", X=y, y=y)
|
|
|
|
parallel, p_func, n_jobs = parallel_func(_run_ems, n_jobs=n_jobs)
|
|
# FIXME this parallelization should be removed.
|
|
# 1) it's numpy computation so it's already efficient,
|
|
# 2) it duplicates the data in RAM,
|
|
# 3) the computation is already super fast.
|
|
out = parallel(
|
|
p_func(_ems_diff, data, cond_idx, train, test) for train, test in cv_splits
|
|
)
|
|
|
|
surrogate_trials, spatial_filter = zip(*out)
|
|
surrogate_trials = np.array(surrogate_trials)
|
|
spatial_filter = np.mean(spatial_filter, axis=0)
|
|
|
|
return surrogate_trials, spatial_filter, epochs.events[:, 2]
|
|
|
|
|
|
def _ems_diff(data0, data1):
|
|
"""Compute the default diff objective function."""
|
|
return np.mean(data0, axis=0) - np.mean(data1, axis=0)
|
|
|
|
|
|
def _run_ems(objective_function, data, cond_idx, train, test):
|
|
"""Run EMS."""
|
|
d = objective_function(*(data[np.intersect1d(c, train)] for c in cond_idx))
|
|
d /= np.sqrt(np.sum(d**2, axis=0))[None, :]
|
|
# compute surrogates
|
|
return np.sum(data[test[0]] * d, axis=0), d
|