针对pulse-transit的工具
This commit is contained in:
948
dist/client/mne/decoding/csp.py
vendored
Normal file
948
dist/client/mne/decoding/csp.py
vendored
Normal file
@@ -0,0 +1,948 @@
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import copy as cp
|
||||
|
||||
import numpy as np
|
||||
from scipy.linalg import eigh
|
||||
|
||||
from .._fiff.meas_info import create_info
|
||||
from ..cov import _compute_rank_raw_array, _regularized_covariance, _smart_eigh
|
||||
from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
|
||||
from ..evoked import EvokedArray
|
||||
from ..utils import (
|
||||
_check_option,
|
||||
_validate_type,
|
||||
_verbose_safe_false,
|
||||
copy_doc,
|
||||
fill_doc,
|
||||
pinv,
|
||||
)
|
||||
from .base import BaseEstimator
|
||||
from .mixin import TransformerMixin
|
||||
|
||||
|
||||
@fill_doc
|
||||
class CSP(TransformerMixin, BaseEstimator):
|
||||
"""M/EEG signal decomposition using the Common Spatial Patterns (CSP).
|
||||
|
||||
This class can be used as a supervised decomposition to estimate spatial
|
||||
filters for feature extraction. CSP in the context of EEG was first
|
||||
described in :footcite:`KolesEtAl1990`; a comprehensive tutorial on CSP can
|
||||
be found in :footcite:`BlankertzEtAl2008`. Multi-class solving is
|
||||
implemented from :footcite:`Grosse-WentrupBuss2008`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_components : int (default 4)
|
||||
The number of components to decompose M/EEG signals. This number should
|
||||
be set by cross-validation.
|
||||
reg : float | str | None (default None)
|
||||
If not None (same as ``'empirical'``, default), allow regularization
|
||||
for covariance estimation. If float (between 0 and 1), shrinkage is
|
||||
used. For str values, ``reg`` will be passed as ``method`` to
|
||||
:func:`mne.compute_covariance`.
|
||||
log : None | bool (default None)
|
||||
If ``transform_into`` equals ``'average_power'`` and ``log`` is None or
|
||||
True, then apply a log transform to standardize features, else features
|
||||
are z-scored. If ``transform_into`` is ``'csp_space'``, ``log`` must be
|
||||
None.
|
||||
cov_est : 'concat' | 'epoch' (default 'concat')
|
||||
If ``'concat'``, covariance matrices are estimated on concatenated
|
||||
epochs for each class. If ``'epoch'``, covariance matrices are
|
||||
estimated on each epoch separately and then averaged over each class.
|
||||
transform_into : 'average_power' | 'csp_space' (default 'average_power')
|
||||
If 'average_power' then ``self.transform`` will return the average
|
||||
power of each spatial filter. If ``'csp_space'``, ``self.transform``
|
||||
will return the data in CSP space.
|
||||
norm_trace : bool (default False)
|
||||
Normalize class covariance by its trace. Trace normalization is a step
|
||||
of the original CSP algorithm :footcite:`KolesEtAl1990` to eliminate
|
||||
magnitude variations in the EEG between individuals. It is not applied
|
||||
in more recent work :footcite:`BlankertzEtAl2008`,
|
||||
:footcite:`Grosse-WentrupBuss2008` and can have a negative impact on
|
||||
pattern order.
|
||||
cov_method_params : dict | None
|
||||
Parameters to pass to :func:`mne.compute_covariance`.
|
||||
|
||||
.. versionadded:: 0.16
|
||||
%(rank_none)s
|
||||
|
||||
.. versionadded:: 0.17
|
||||
component_order : 'mutual_info' | 'alternate' (default 'mutual_info')
|
||||
If ``'mutual_info'`` order components by decreasing mutual information
|
||||
(in the two-class case this uses a simplification which orders
|
||||
components by decreasing absolute deviation of the eigenvalues from 0.5
|
||||
:footcite:`BarachantEtAl2010`). For the two-class case, ``'alternate'``
|
||||
orders components by starting with the largest eigenvalue, followed by
|
||||
the smallest, the second-to-largest, the second-to-smallest, and so on
|
||||
:footcite:`BlankertzEtAl2008`.
|
||||
|
||||
.. versionadded:: 0.21
|
||||
|
||||
Attributes
|
||||
----------
|
||||
filters_ : ndarray, shape (n_channels, n_channels)
|
||||
If fit, the CSP components used to decompose the data, else None.
|
||||
patterns_ : ndarray, shape (n_channels, n_channels)
|
||||
If fit, the CSP patterns used to restore M/EEG signals, else None.
|
||||
mean_ : ndarray, shape (n_components,)
|
||||
If fit, the mean squared power for each component.
|
||||
std_ : ndarray, shape (n_components,)
|
||||
If fit, the std squared power for each component.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mne.preprocessing.Xdawn, SPoC
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_components=4,
|
||||
reg=None,
|
||||
log=None,
|
||||
cov_est="concat",
|
||||
transform_into="average_power",
|
||||
norm_trace=False,
|
||||
cov_method_params=None,
|
||||
rank=None,
|
||||
component_order="mutual_info",
|
||||
):
|
||||
# Init default CSP
|
||||
if not isinstance(n_components, int):
|
||||
raise ValueError("n_components must be an integer.")
|
||||
self.n_components = n_components
|
||||
self.rank = rank
|
||||
self.reg = reg
|
||||
|
||||
# Init default cov_est
|
||||
if not (cov_est == "concat" or cov_est == "epoch"):
|
||||
raise ValueError("unknown covariance estimation method")
|
||||
self.cov_est = cov_est
|
||||
|
||||
# Init default transform_into
|
||||
self.transform_into = _check_option(
|
||||
"transform_into", transform_into, ["average_power", "csp_space"]
|
||||
)
|
||||
|
||||
# Init default log
|
||||
if transform_into == "average_power":
|
||||
if log is not None and not isinstance(log, bool):
|
||||
raise ValueError(
|
||||
'log must be a boolean if transform_into == "average_power".'
|
||||
)
|
||||
else:
|
||||
if log is not None:
|
||||
raise ValueError('log must be a None if transform_into == "csp_space".')
|
||||
self.log = log
|
||||
|
||||
_validate_type(norm_trace, bool, "norm_trace")
|
||||
self.norm_trace = norm_trace
|
||||
self.cov_method_params = cov_method_params
|
||||
self.component_order = _check_option(
|
||||
"component_order", component_order, ("mutual_info", "alternate")
|
||||
)
|
||||
|
||||
def _check_Xy(self, X, y=None):
|
||||
"""Check input data."""
|
||||
if not isinstance(X, np.ndarray):
|
||||
raise ValueError(f"X should be of type ndarray (got {type(X)}).")
|
||||
if y is not None:
|
||||
if len(X) != len(y) or len(y) < 1:
|
||||
raise ValueError("X and y must have the same length.")
|
||||
if X.ndim < 3:
|
||||
raise ValueError("X must have at least 3 dimensions.")
|
||||
|
||||
def fit(self, X, y):
|
||||
"""Estimate the CSP decomposition on epochs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : ndarray, shape (n_epochs, n_channels, n_times)
|
||||
The data on which to estimate the CSP.
|
||||
y : array, shape (n_epochs,)
|
||||
The class for each epoch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
self : instance of CSP
|
||||
Returns the modified instance.
|
||||
"""
|
||||
self._check_Xy(X, y)
|
||||
|
||||
self._classes = np.unique(y)
|
||||
n_classes = len(self._classes)
|
||||
if n_classes < 2:
|
||||
raise ValueError("n_classes must be >= 2.")
|
||||
if n_classes > 2 and self.component_order == "alternate":
|
||||
raise ValueError(
|
||||
"component_order='alternate' requires two classes, but data contains "
|
||||
f"{n_classes} classes; use component_order='mutual_info' instead."
|
||||
)
|
||||
|
||||
# Convert rank to one that will run
|
||||
_validate_type(self.rank, (dict, None, str), "rank")
|
||||
|
||||
covs, sample_weights = self._compute_covariance_matrices(X, y)
|
||||
eigen_vectors, eigen_values = self._decompose_covs(covs, sample_weights)
|
||||
ix = self._order_components(
|
||||
covs, sample_weights, eigen_vectors, eigen_values, self.component_order
|
||||
)
|
||||
|
||||
eigen_vectors = eigen_vectors[:, ix]
|
||||
|
||||
self.filters_ = eigen_vectors.T
|
||||
self.patterns_ = pinv(eigen_vectors)
|
||||
|
||||
pick_filters = self.filters_[: self.n_components]
|
||||
X = np.asarray([np.dot(pick_filters, epoch) for epoch in X])
|
||||
|
||||
# compute features (mean power)
|
||||
X = (X**2).mean(axis=2)
|
||||
|
||||
# To standardize features
|
||||
self.mean_ = X.mean(axis=0)
|
||||
self.std_ = X.std(axis=0)
|
||||
|
||||
return self
|
||||
|
||||
def transform(self, X):
|
||||
"""Estimate epochs sources given the CSP filters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array, shape (n_epochs, n_channels, n_times)
|
||||
The data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
X : ndarray
|
||||
If self.transform_into == 'average_power' then returns the power of
|
||||
CSP features averaged over time and shape (n_epochs, n_components)
|
||||
If self.transform_into == 'csp_space' then returns the data in CSP
|
||||
space and shape is (n_epochs, n_components, n_times).
|
||||
"""
|
||||
if not isinstance(X, np.ndarray):
|
||||
raise ValueError(f"X should be of type ndarray (got {type(X)}).")
|
||||
if self.filters_ is None:
|
||||
raise RuntimeError(
|
||||
"No filters available. Please first fit CSP decomposition."
|
||||
)
|
||||
|
||||
pick_filters = self.filters_[: self.n_components]
|
||||
X = np.asarray([np.dot(pick_filters, epoch) for epoch in X])
|
||||
|
||||
# compute features (mean band power)
|
||||
if self.transform_into == "average_power":
|
||||
X = (X**2).mean(axis=2)
|
||||
log = True if self.log is None else self.log
|
||||
if log:
|
||||
X = np.log(X)
|
||||
else:
|
||||
X -= self.mean_
|
||||
X /= self.std_
|
||||
return X
|
||||
|
||||
def inverse_transform(self, X):
|
||||
"""Project CSP features back to sensor space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array, shape (n_epochs, n_components)
|
||||
The data in CSP power space.
|
||||
|
||||
Returns
|
||||
-------
|
||||
X : ndarray
|
||||
The data in sensor space and shape (n_epochs, n_channels, n_components).
|
||||
"""
|
||||
if self.transform_into != "average_power":
|
||||
raise NotImplementedError(
|
||||
"Can only inverse transform CSP features when transform_into is "
|
||||
"'average_power'."
|
||||
)
|
||||
if not (X.ndim == 2 and X.shape[1] == self.n_components):
|
||||
raise ValueError(
|
||||
f"X must be 2D with X[1]={self.n_components}, got {X.shape=}"
|
||||
)
|
||||
return X[:, np.newaxis, :] * self.patterns_[: self.n_components].T
|
||||
|
||||
@copy_doc(TransformerMixin.fit_transform)
|
||||
def fit_transform(self, X, y, **fit_params): # noqa: D102
|
||||
return super().fit_transform(X, y=y, **fit_params)
|
||||
|
||||
@fill_doc
|
||||
def plot_patterns(
|
||||
self,
|
||||
info,
|
||||
components=None,
|
||||
*,
|
||||
average=None,
|
||||
ch_type=None,
|
||||
scalings=None,
|
||||
sensors=True,
|
||||
show_names=False,
|
||||
mask=None,
|
||||
mask_params=None,
|
||||
contours=6,
|
||||
outlines="head",
|
||||
sphere=None,
|
||||
image_interp=_INTERPOLATION_DEFAULT,
|
||||
extrapolate=_EXTRAPOLATE_DEFAULT,
|
||||
border=_BORDER_DEFAULT,
|
||||
res=64,
|
||||
size=1,
|
||||
cmap="RdBu_r",
|
||||
vlim=(None, None),
|
||||
cnorm=None,
|
||||
colorbar=True,
|
||||
cbar_fmt="%3.1f",
|
||||
units=None,
|
||||
axes=None,
|
||||
name_format="CSP%01d",
|
||||
nrows=1,
|
||||
ncols="auto",
|
||||
show=True,
|
||||
):
|
||||
"""Plot topographic patterns of components.
|
||||
|
||||
The patterns explain how the measured data was generated from the
|
||||
neural sources (a.k.a. the forward model).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
%(info_not_none)s Used for fitting. If not available, consider using
|
||||
:func:`mne.create_info`.
|
||||
components : float | array of float | None
|
||||
The patterns to plot. If ``None``, all components will be shown.
|
||||
%(average_plot_evoked_topomap)s
|
||||
%(ch_type_topomap)s
|
||||
scalings : dict | float | None
|
||||
The scalings of the channel types to be applied for plotting.
|
||||
If None, defaults to ``dict(eeg=1e6, grad=1e13, mag=1e15)``.
|
||||
%(sensors_topomap)s
|
||||
%(show_names_topomap)s
|
||||
%(mask_patterns_topomap)s
|
||||
%(mask_params_topomap)s
|
||||
%(contours_topomap)s
|
||||
%(outlines_topomap)s
|
||||
%(sphere_topomap_auto)s
|
||||
%(image_interp_topomap)s
|
||||
%(extrapolate_topomap)s
|
||||
|
||||
.. versionadded:: 1.3
|
||||
%(border_topomap)s
|
||||
|
||||
.. versionadded:: 1.3
|
||||
%(res_topomap)s
|
||||
%(size_topomap)s
|
||||
%(cmap_topomap)s
|
||||
%(vlim_plot_topomap)s
|
||||
|
||||
.. versionadded:: 1.3
|
||||
%(cnorm)s
|
||||
|
||||
.. versionadded:: 1.3
|
||||
%(colorbar_topomap)s
|
||||
%(cbar_fmt_topomap)s
|
||||
%(units_topomap)s
|
||||
%(axes_evoked_plot_topomap)s
|
||||
name_format : str
|
||||
String format for topomap values. Defaults to "CSP%%01d".
|
||||
%(nrows_ncols_topomap)s
|
||||
|
||||
.. versionadded:: 1.3
|
||||
%(show)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of matplotlib.figure.Figure
|
||||
The figure.
|
||||
"""
|
||||
if units is None:
|
||||
units = "AU"
|
||||
if components is None:
|
||||
components = np.arange(self.n_components)
|
||||
|
||||
# set sampling frequency to have 1 component per time point
|
||||
info = cp.deepcopy(info)
|
||||
with info._unlock():
|
||||
info["sfreq"] = 1.0
|
||||
# create an evoked
|
||||
patterns = EvokedArray(self.patterns_.T, info, tmin=0)
|
||||
# the call plot_topomap
|
||||
fig = patterns.plot_topomap(
|
||||
times=components,
|
||||
average=average,
|
||||
ch_type=ch_type,
|
||||
scalings=scalings,
|
||||
sensors=sensors,
|
||||
show_names=show_names,
|
||||
mask=mask,
|
||||
mask_params=mask_params,
|
||||
contours=contours,
|
||||
outlines=outlines,
|
||||
sphere=sphere,
|
||||
image_interp=image_interp,
|
||||
extrapolate=extrapolate,
|
||||
border=border,
|
||||
res=res,
|
||||
size=size,
|
||||
cmap=cmap,
|
||||
vlim=vlim,
|
||||
cnorm=cnorm,
|
||||
colorbar=colorbar,
|
||||
cbar_fmt=cbar_fmt,
|
||||
units=units,
|
||||
axes=axes,
|
||||
time_format=name_format,
|
||||
nrows=nrows,
|
||||
ncols=ncols,
|
||||
show=show,
|
||||
)
|
||||
return fig
|
||||
|
||||
@fill_doc
|
||||
def plot_filters(
|
||||
self,
|
||||
info,
|
||||
components=None,
|
||||
*,
|
||||
average=None,
|
||||
ch_type=None,
|
||||
scalings=None,
|
||||
sensors=True,
|
||||
show_names=False,
|
||||
mask=None,
|
||||
mask_params=None,
|
||||
contours=6,
|
||||
outlines="head",
|
||||
sphere=None,
|
||||
image_interp=_INTERPOLATION_DEFAULT,
|
||||
extrapolate=_EXTRAPOLATE_DEFAULT,
|
||||
border=_BORDER_DEFAULT,
|
||||
res=64,
|
||||
size=1,
|
||||
cmap="RdBu_r",
|
||||
vlim=(None, None),
|
||||
cnorm=None,
|
||||
colorbar=True,
|
||||
cbar_fmt="%3.1f",
|
||||
units=None,
|
||||
axes=None,
|
||||
name_format="CSP%01d",
|
||||
nrows=1,
|
||||
ncols="auto",
|
||||
show=True,
|
||||
):
|
||||
"""Plot topographic filters of components.
|
||||
|
||||
The filters are used to extract discriminant neural sources from
|
||||
the measured data (a.k.a. the backward model).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
%(info_not_none)s Used for fitting. If not available, consider using
|
||||
:func:`mne.create_info`.
|
||||
components : float | array of float | None
|
||||
The patterns to plot. If ``None``, all components will be shown.
|
||||
%(average_plot_evoked_topomap)s
|
||||
%(ch_type_topomap)s
|
||||
scalings : dict | float | None
|
||||
The scalings of the channel types to be applied for plotting.
|
||||
If None, defaults to ``dict(eeg=1e6, grad=1e13, mag=1e15)``.
|
||||
%(sensors_topomap)s
|
||||
%(show_names_topomap)s
|
||||
%(mask_patterns_topomap)s
|
||||
%(mask_params_topomap)s
|
||||
%(contours_topomap)s
|
||||
%(outlines_topomap)s
|
||||
%(sphere_topomap_auto)s
|
||||
%(image_interp_topomap)s
|
||||
%(extrapolate_topomap)s
|
||||
|
||||
.. versionadded:: 1.3
|
||||
%(border_topomap)s
|
||||
|
||||
.. versionadded:: 1.3
|
||||
%(res_topomap)s
|
||||
%(size_topomap)s
|
||||
%(cmap_topomap)s
|
||||
%(vlim_plot_topomap_psd)s
|
||||
|
||||
.. versionadded:: 1.3
|
||||
%(cnorm)s
|
||||
|
||||
.. versionadded:: 1.3
|
||||
%(colorbar_topomap)s
|
||||
%(cbar_fmt_topomap)s
|
||||
%(units_topomap)s
|
||||
%(axes_evoked_plot_topomap)s
|
||||
name_format : str
|
||||
String format for topomap values. Defaults to "CSP%%01d".
|
||||
%(nrows_ncols_topomap)s
|
||||
|
||||
.. versionadded:: 1.3
|
||||
%(show)s
|
||||
|
||||
Returns
|
||||
-------
|
||||
fig : instance of matplotlib.figure.Figure
|
||||
The figure.
|
||||
"""
|
||||
if units is None:
|
||||
units = "AU"
|
||||
if components is None:
|
||||
components = np.arange(self.n_components)
|
||||
|
||||
# set sampling frequency to have 1 component per time point
|
||||
info = cp.deepcopy(info)
|
||||
with info._unlock():
|
||||
info["sfreq"] = 1.0
|
||||
# create an evoked
|
||||
filters = EvokedArray(self.filters_.T, info, tmin=0)
|
||||
# the call plot_topomap
|
||||
fig = filters.plot_topomap(
|
||||
times=components,
|
||||
average=average,
|
||||
ch_type=ch_type,
|
||||
scalings=scalings,
|
||||
sensors=sensors,
|
||||
show_names=show_names,
|
||||
mask=mask,
|
||||
mask_params=mask_params,
|
||||
contours=contours,
|
||||
outlines=outlines,
|
||||
sphere=sphere,
|
||||
image_interp=image_interp,
|
||||
extrapolate=extrapolate,
|
||||
border=border,
|
||||
res=res,
|
||||
size=size,
|
||||
cmap=cmap,
|
||||
vlim=vlim,
|
||||
cnorm=cnorm,
|
||||
colorbar=colorbar,
|
||||
cbar_fmt=cbar_fmt,
|
||||
units=units,
|
||||
axes=axes,
|
||||
time_format=name_format,
|
||||
nrows=nrows,
|
||||
ncols=ncols,
|
||||
show=show,
|
||||
)
|
||||
return fig
|
||||
|
||||
def _compute_covariance_matrices(self, X, y):
|
||||
_, n_channels, _ = X.shape
|
||||
|
||||
if self.cov_est == "concat":
|
||||
cov_estimator = self._concat_cov
|
||||
elif self.cov_est == "epoch":
|
||||
cov_estimator = self._epoch_cov
|
||||
|
||||
# Someday we could allow the user to pass this, then we wouldn't need to convert
|
||||
# but in the meantime they can use a pipeline with a scaler
|
||||
self._info = create_info(n_channels, 1000.0, "mag")
|
||||
if isinstance(self.rank, dict):
|
||||
self._rank = {"mag": sum(self.rank.values())}
|
||||
else:
|
||||
self._rank = _compute_rank_raw_array(
|
||||
X.transpose(1, 0, 2).reshape(X.shape[1], -1),
|
||||
self._info,
|
||||
rank=self.rank,
|
||||
scalings=None,
|
||||
log_ch_type="data",
|
||||
)
|
||||
|
||||
covs = []
|
||||
sample_weights = []
|
||||
for ci, this_class in enumerate(self._classes):
|
||||
cov, weight = cov_estimator(
|
||||
X[y == this_class],
|
||||
cov_kind=f"class={this_class}",
|
||||
log_rank=ci == 0,
|
||||
)
|
||||
|
||||
if self.norm_trace:
|
||||
cov /= np.trace(cov)
|
||||
|
||||
covs.append(cov)
|
||||
sample_weights.append(weight)
|
||||
|
||||
return np.stack(covs), np.array(sample_weights)
|
||||
|
||||
def _concat_cov(self, x_class, *, cov_kind, log_rank):
|
||||
"""Concatenate epochs before computing the covariance."""
|
||||
_, n_channels, _ = x_class.shape
|
||||
|
||||
x_class = x_class.transpose(1, 0, 2).reshape(n_channels, -1)
|
||||
cov = _regularized_covariance(
|
||||
x_class,
|
||||
reg=self.reg,
|
||||
method_params=self.cov_method_params,
|
||||
rank=self._rank,
|
||||
info=self._info,
|
||||
cov_kind=cov_kind,
|
||||
log_rank=log_rank,
|
||||
log_ch_type="data",
|
||||
)
|
||||
weight = x_class.shape[0]
|
||||
|
||||
return cov, weight
|
||||
|
||||
def _epoch_cov(self, x_class, *, cov_kind, log_rank):
|
||||
"""Mean of per-epoch covariances."""
|
||||
cov = sum(
|
||||
_regularized_covariance(
|
||||
this_X,
|
||||
reg=self.reg,
|
||||
method_params=self.cov_method_params,
|
||||
rank=self._rank,
|
||||
info=self._info,
|
||||
cov_kind=cov_kind,
|
||||
log_rank=log_rank and ii == 0,
|
||||
log_ch_type="data",
|
||||
)
|
||||
for ii, this_X in enumerate(x_class)
|
||||
)
|
||||
cov /= len(x_class)
|
||||
weight = len(x_class)
|
||||
|
||||
return cov, weight
|
||||
|
||||
def _decompose_covs(self, covs, sample_weights):
|
||||
n_classes = len(covs)
|
||||
n_channels = covs[0].shape[0]
|
||||
assert self._rank is not None # should happen in _compute_covariance_matrices
|
||||
_, sub_vec, mask = _smart_eigh(
|
||||
covs.mean(0),
|
||||
self._info,
|
||||
self._rank,
|
||||
proj_subspace=True,
|
||||
do_compute_rank=False,
|
||||
log_ch_type="data",
|
||||
verbose=_verbose_safe_false(),
|
||||
)
|
||||
sub_vec = sub_vec[mask]
|
||||
covs = np.array([sub_vec @ cov @ sub_vec.T for cov in covs], float)
|
||||
assert covs[0].shape == (mask.sum(),) * 2
|
||||
if n_classes == 2:
|
||||
eigen_values, eigen_vectors = eigh(covs[0], covs.sum(0))
|
||||
else:
|
||||
# The multiclass case is adapted from
|
||||
# http://github.com/alexandrebarachant/pyRiemann
|
||||
eigen_vectors, D = _ajd_pham(covs)
|
||||
eigen_vectors = self._normalize_eigenvectors(
|
||||
eigen_vectors.T, covs, sample_weights
|
||||
)
|
||||
eigen_values = None
|
||||
# project back
|
||||
eigen_vectors = sub_vec.T @ eigen_vectors
|
||||
assert eigen_vectors.shape == (n_channels, mask.sum())
|
||||
return eigen_vectors, eigen_values
|
||||
|
||||
def _compute_mutual_info(self, covs, sample_weights, eigen_vectors):
|
||||
class_probas = sample_weights / sample_weights.sum()
|
||||
|
||||
mutual_info = []
|
||||
for jj in range(eigen_vectors.shape[1]):
|
||||
aa, bb = 0, 0
|
||||
for cov, prob in zip(covs, class_probas):
|
||||
tmp = np.dot(np.dot(eigen_vectors[:, jj].T, cov), eigen_vectors[:, jj])
|
||||
aa += prob * np.log(np.sqrt(tmp))
|
||||
bb += prob * (tmp**2 - 1)
|
||||
mi = -(aa + (3.0 / 16) * (bb**2))
|
||||
mutual_info.append(mi)
|
||||
|
||||
return mutual_info
|
||||
|
||||
def _normalize_eigenvectors(self, eigen_vectors, covs, sample_weights):
|
||||
# Here we apply an euclidean mean. See pyRiemann for other metrics
|
||||
mean_cov = np.average(covs, axis=0, weights=sample_weights)
|
||||
|
||||
for ii in range(eigen_vectors.shape[1]):
|
||||
tmp = np.dot(np.dot(eigen_vectors[:, ii].T, mean_cov), eigen_vectors[:, ii])
|
||||
eigen_vectors[:, ii] /= np.sqrt(tmp)
|
||||
return eigen_vectors
|
||||
|
||||
def _order_components(
|
||||
self, covs, sample_weights, eigen_vectors, eigen_values, component_order
|
||||
):
|
||||
n_classes = len(self._classes)
|
||||
if component_order == "mutual_info" and n_classes > 2:
|
||||
mutual_info = self._compute_mutual_info(covs, sample_weights, eigen_vectors)
|
||||
ix = np.argsort(mutual_info)[::-1]
|
||||
elif component_order == "mutual_info" and n_classes == 2:
|
||||
ix = np.argsort(np.abs(eigen_values - 0.5))[::-1]
|
||||
elif component_order == "alternate" and n_classes == 2:
|
||||
i = np.argsort(eigen_values)
|
||||
ix = np.empty_like(i)
|
||||
ix[1::2] = i[: len(i) // 2]
|
||||
ix[0::2] = i[len(i) // 2 :][::-1]
|
||||
return ix
|
||||
|
||||
|
||||
def _ajd_pham(X, eps=1e-6, max_iter=15):
|
||||
"""Approximate joint diagonalization based on Pham's algorithm.
|
||||
|
||||
This is a direct implementation of the PHAM's AJD algorithm [1].
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : ndarray, shape (n_epochs, n_channels, n_channels)
|
||||
A set of covariance matrices to diagonalize.
|
||||
eps : float, default 1e-6
|
||||
The tolerance for stopping criterion.
|
||||
max_iter : int, default 1000
|
||||
The maximum number of iteration to reach convergence.
|
||||
|
||||
Returns
|
||||
-------
|
||||
V : ndarray, shape (n_channels, n_channels)
|
||||
The diagonalizer.
|
||||
D : ndarray, shape (n_epochs, n_channels, n_channels)
|
||||
The set of quasi diagonal matrices.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Pham, Dinh Tuan. "Joint approximate diagonalization of positive
|
||||
definite Hermitian matrices." SIAM Journal on Matrix Analysis and
|
||||
Applications 22, no. 4 (2001): 1136-1152.
|
||||
|
||||
"""
|
||||
# Adapted from http://github.com/alexandrebarachant/pyRiemann
|
||||
n_epochs = X.shape[0]
|
||||
|
||||
# Reshape input matrix
|
||||
A = np.concatenate(X, axis=0).T
|
||||
|
||||
# Init variables
|
||||
n_times, n_m = A.shape
|
||||
V = np.eye(n_times)
|
||||
epsilon = n_times * (n_times - 1) * eps
|
||||
|
||||
for it in range(max_iter):
|
||||
decr = 0
|
||||
for ii in range(1, n_times):
|
||||
for jj in range(ii):
|
||||
Ii = np.arange(ii, n_m, n_times)
|
||||
Ij = np.arange(jj, n_m, n_times)
|
||||
|
||||
c1 = A[ii, Ii]
|
||||
c2 = A[jj, Ij]
|
||||
|
||||
g12 = np.mean(A[ii, Ij] / c1)
|
||||
g21 = np.mean(A[ii, Ij] / c2)
|
||||
|
||||
omega21 = np.mean(c1 / c2)
|
||||
omega12 = np.mean(c2 / c1)
|
||||
omega = np.sqrt(omega12 * omega21)
|
||||
|
||||
tmp = np.sqrt(omega21 / omega12)
|
||||
tmp1 = (tmp * g12 + g21) / (omega + 1)
|
||||
tmp2 = (tmp * g12 - g21) / max(omega - 1, 1e-9)
|
||||
|
||||
h12 = tmp1 + tmp2
|
||||
h21 = np.conj((tmp1 - tmp2) / tmp)
|
||||
|
||||
decr += n_epochs * (g12 * np.conj(h12) + g21 * h21) / 2.0
|
||||
|
||||
tmp = 1 + 1.0j * 0.5 * np.imag(h12 * h21)
|
||||
tmp = np.real(tmp + np.sqrt(tmp**2 - h12 * h21))
|
||||
tau = np.array([[1, -h12 / tmp], [-h21 / tmp, 1]])
|
||||
|
||||
A[[ii, jj], :] = np.dot(tau, A[[ii, jj], :])
|
||||
tmp = np.c_[A[:, Ii], A[:, Ij]]
|
||||
tmp = np.reshape(tmp, (n_times * n_epochs, 2), order="F")
|
||||
tmp = np.dot(tmp, tau.T)
|
||||
|
||||
tmp = np.reshape(tmp, (n_times, n_epochs * 2), order="F")
|
||||
A[:, Ii] = tmp[:, :n_epochs]
|
||||
A[:, Ij] = tmp[:, n_epochs:]
|
||||
V[[ii, jj], :] = np.dot(tau, V[[ii, jj], :])
|
||||
if decr < epsilon:
|
||||
break
|
||||
D = np.reshape(A, (n_times, -1, n_times)).transpose(1, 0, 2)
|
||||
return V, D
|
||||
|
||||
|
||||
@fill_doc
|
||||
class SPoC(CSP):
|
||||
"""Implementation of the SPoC spatial filtering.
|
||||
|
||||
Source Power Comodulation (SPoC) :footcite:`DahneEtAl2014` allows to
|
||||
extract spatial filters and
|
||||
patterns by using a target (continuous) variable in the decomposition
|
||||
process in order to give preference to components whose power correlates
|
||||
with the target variable.
|
||||
|
||||
SPoC can be seen as an extension of the CSP driven by a continuous
|
||||
variable rather than a discrete variable. Typical applications include
|
||||
extraction of motor patterns using EMG power or audio patterns using sound
|
||||
envelope.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_components : int
|
||||
The number of components to decompose M/EEG signals.
|
||||
reg : float | str | None (default None)
|
||||
If not None (same as ``'empirical'``, default), allow
|
||||
regularization for covariance estimation.
|
||||
If float, shrinkage is used (0 <= shrinkage <= 1).
|
||||
For str options, ``reg`` will be passed to ``method`` to
|
||||
:func:`mne.compute_covariance`.
|
||||
log : None | bool (default None)
|
||||
If transform_into == 'average_power' and log is None or True, then
|
||||
applies a log transform to standardize the features, else the features
|
||||
are z-scored. If transform_into == 'csp_space', then log must be None.
|
||||
transform_into : {'average_power', 'csp_space'}
|
||||
If 'average_power' then self.transform will return the average power of
|
||||
each spatial filter. If 'csp_space' self.transform will return the data
|
||||
in CSP space. Defaults to 'average_power'.
|
||||
cov_method_params : dict | None
|
||||
Parameters to pass to :func:`mne.compute_covariance`.
|
||||
|
||||
.. versionadded:: 0.16
|
||||
%(rank_none)s
|
||||
|
||||
.. versionadded:: 0.17
|
||||
|
||||
Attributes
|
||||
----------
|
||||
filters_ : ndarray, shape (n_channels, n_channels)
|
||||
If fit, the SPoC spatial filters, else None.
|
||||
patterns_ : ndarray, shape (n_channels, n_channels)
|
||||
If fit, the SPoC spatial patterns, else None.
|
||||
mean_ : ndarray, shape (n_components,)
|
||||
If fit, the mean squared power for each component.
|
||||
std_ : ndarray, shape (n_components,)
|
||||
If fit, the std squared power for each component.
|
||||
|
||||
See Also
|
||||
--------
|
||||
mne.preprocessing.Xdawn, CSP
|
||||
|
||||
References
|
||||
----------
|
||||
.. footbibliography::
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_components=4,
|
||||
reg=None,
|
||||
log=None,
|
||||
transform_into="average_power",
|
||||
cov_method_params=None,
|
||||
rank=None,
|
||||
):
|
||||
"""Init of SPoC."""
|
||||
super().__init__(
|
||||
n_components=n_components,
|
||||
reg=reg,
|
||||
log=log,
|
||||
cov_est="epoch",
|
||||
norm_trace=False,
|
||||
transform_into=transform_into,
|
||||
rank=rank,
|
||||
cov_method_params=cov_method_params,
|
||||
)
|
||||
# Covariance estimation have to be done on the single epoch level,
|
||||
# unlike CSP where covariance estimation can also be achieved through
|
||||
# concatenation of all epochs from the same class.
|
||||
delattr(self, "cov_est")
|
||||
delattr(self, "norm_trace")
|
||||
|
||||
def fit(self, X, y):
|
||||
"""Estimate the SPoC decomposition on epochs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : ndarray, shape (n_epochs, n_channels, n_times)
|
||||
The data on which to estimate the SPoC.
|
||||
y : array, shape (n_epochs,)
|
||||
The class for each epoch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
self : instance of SPoC
|
||||
Returns the modified instance.
|
||||
"""
|
||||
self._check_Xy(X, y)
|
||||
|
||||
if len(np.unique(y)) < 2:
|
||||
raise ValueError("y must have at least two distinct values.")
|
||||
|
||||
# The following code is directly copied from pyRiemann
|
||||
|
||||
# Normalize target variable
|
||||
target = y.astype(np.float64)
|
||||
target -= target.mean()
|
||||
target /= target.std()
|
||||
|
||||
n_epochs, n_channels = X.shape[:2]
|
||||
|
||||
# Estimate single trial covariance
|
||||
covs = np.empty((n_epochs, n_channels, n_channels))
|
||||
for ii, epoch in enumerate(X):
|
||||
covs[ii] = _regularized_covariance(
|
||||
epoch,
|
||||
reg=self.reg,
|
||||
method_params=self.cov_method_params,
|
||||
rank=self.rank,
|
||||
log_ch_type="data",
|
||||
log_rank=ii == 0,
|
||||
)
|
||||
|
||||
C = covs.mean(0)
|
||||
Cz = np.mean(covs * target[:, np.newaxis, np.newaxis], axis=0)
|
||||
|
||||
# solve eigenvalue decomposition
|
||||
evals, evecs = eigh(Cz, C)
|
||||
evals = evals.real
|
||||
evecs = evecs.real
|
||||
# sort vectors
|
||||
ix = np.argsort(np.abs(evals))[::-1]
|
||||
|
||||
# sort eigenvectors
|
||||
evecs = evecs[:, ix].T
|
||||
|
||||
# spatial patterns
|
||||
self.patterns_ = pinv(evecs).T # n_channels x n_channels
|
||||
self.filters_ = evecs # n_channels x n_channels
|
||||
|
||||
pick_filters = self.filters_[: self.n_components]
|
||||
X = np.asarray([np.dot(pick_filters, epoch) for epoch in X])
|
||||
|
||||
# compute features (mean band power)
|
||||
X = (X**2).mean(axis=-1)
|
||||
|
||||
# To standardize features
|
||||
self.mean_ = X.mean(axis=0)
|
||||
self.std_ = X.std(axis=0)
|
||||
|
||||
return self
|
||||
|
||||
def transform(self, X):
|
||||
"""Estimate epochs sources given the SPoC filters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array, shape (n_epochs, n_channels, n_times)
|
||||
The data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
X : ndarray
|
||||
If self.transform_into == 'average_power' then returns the power of
|
||||
CSP features averaged over time and shape (n_epochs, n_components)
|
||||
If self.transform_into == 'csp_space' then returns the data in CSP
|
||||
space and shape is (n_epochs, n_components, n_times).
|
||||
"""
|
||||
return super().transform(X)
|
||||
Reference in New Issue
Block a user