# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np
|
||
from scipy.linalg import eigh
|
||
|
||
from .._fiff.pick import _picks_to_idx
|
||
from ..cov import Covariance, _regularized_covariance
|
||
from ..defaults import _handle_default
|
||
from ..filter import filter_data
|
||
from ..fixes import BaseEstimator
|
||
from ..rank import compute_rank
|
||
from ..time_frequency import psd_array_welch
|
||
from ..utils import (
|
||
_check_option,
|
||
_time_mask,
|
||
_validate_type,
|
||
_verbose_safe_false,
|
||
fill_doc,
|
||
logger,
|
||
)
|
||
from .mixin import TransformerMixin
|
||
|
||
|
||
@fill_doc
class SSD(BaseEstimator, TransformerMixin):
    """
    Signal decomposition using the Spatio-Spectral Decomposition (SSD).

    SSD seeks to maximize the power at a frequency band of interest while
    simultaneously minimizing it at the flanking (surrounding) frequency bins
    (considered noise). It extremizes the covariance matrices associated with
    signal and noise :footcite:`NikulinEtAl2011`.

    SSD can either be used as a dimensionality reduction method or a
    'denoised' low rank factorization method :footcite:`HaufeEtAl2014b`.

    Parameters
    ----------
    %(info_not_none)s Must match the input data.
    filt_params_signal : dict
        Filtering for the frequencies of interest. Must contain the keys
        ``'l_freq'`` and ``'h_freq'``.
    filt_params_noise : dict
        Filtering for the frequencies of non-interest. Must contain the keys
        ``'l_freq'`` and ``'h_freq'``, and its band must enclose the signal
        band.
    reg : float | str | None (default)
        Which covariance estimator to use.
        If not None (same as 'empirical'), allow regularization for covariance
        estimation. If float, shrinkage is used (0 <= shrinkage <= 1). For str
        options, reg will be passed to method :func:`mne.compute_covariance`.
    n_components : int | None (default None)
        The number of components to extract from the signal.
        If None, the number of components equal to the rank of the data are
        returned (see ``rank``).
    picks : array of int | None (default None)
        The indices of good channels.
    sort_by_spectral_ratio : bool (default True)
        If set to True, the components are sorted according to the spectral
        ratio.
        See Eq. (24) in :footcite:`NikulinEtAl2011`.
    return_filtered : bool (default False)
        If return_filtered is True, data is bandpassed and projected onto the
        SSD components.
    n_fft : int (default None)
        If sort_by_spectral_ratio is set to True, then the SSD sources will be
        sorted according to their spectral ratio which is calculated based on
        :func:`mne.time_frequency.psd_array_welch`. The n_fft parameter sets the
        length of FFT used. If None, ``int(info['sfreq'])`` is used.
        See :func:`mne.time_frequency.psd_array_welch` for more information.
    cov_method_params : dict | None (default None)
        As in :class:`mne.decoding.SPoC`
        The default is None.
    rank : None | dict | 'info' | 'full'
        As in :class:`mne.decoding.SPoC`
        This controls the rank computation that can be read from the
        measurement info or estimated from the data, which determines the
        maximum possible number of components.
        See Notes of :func:`mne.compute_rank` for details.
        We recommend to use 'full' when working with epoched data.

    Attributes
    ----------
    filters_ : array, shape (n_channels, n_components)
        The spatial filters to be multiplied with the signal.
    patterns_ : array, shape (n_components, n_channels)
        The patterns for reconstructing the signal from the filtered data.
    eigvals_ : array
        The generalized eigenvalues, sorted in descending order.
    sorter_spec : array of int | Ellipsis
        Component ordering learned during ``fit``. It is ``Ellipsis`` (no
        reordering) when ``sort_by_spectral_ratio`` is False.

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        info,
        filt_params_signal,
        filt_params_noise,
        reg=None,
        n_components=None,
        picks=None,
        sort_by_spectral_ratio=True,
        return_filtered=False,
        n_fft=None,
        cov_method_params=None,
        rank=None,
    ):
        """Initialize instance."""
        dicts = {"signal": filt_params_signal, "noise": filt_params_noise}
        # Both band edges must be present and numeric in both filter dicts.
        for key in ("signal", "noise"):
            for param in ("l", "h"):
                freq_name = param + "_freq"
                if freq_name not in dicts[key]:
                    raise ValueError(
                        f"{freq_name} must be defined in filter parameters for {key}"
                    )
                val = dicts[key][freq_name]
                if not isinstance(val, (int, float)):
                    _validate_type(val, ("numeric",), f"{key} {param}_freq")
        # check freq bands
        if (
            filt_params_noise["l_freq"] > filt_params_signal["l_freq"]
            or filt_params_signal["h_freq"] > filt_params_noise["h_freq"]
        ):
            raise ValueError(
                "Wrongly specified frequency bands!\n"
                "The signal band-pass must be within the noise "
                "band-pass!"
            )
        self.picks_ = _picks_to_idx(info, picks, none="data", exclude="bads")
        del picks
        ch_types = info.get_channel_types(picks=self.picks_, unique=True)
        if len(ch_types) > 1:
            raise ValueError(
                "At this point SSD only supports fitting "
                f"single channel types. Your info has {len(ch_types)} types"
            )
        self.info = info
        self.freqs_signal = (filt_params_signal["l_freq"], filt_params_signal["h_freq"])
        self.freqs_noise = (filt_params_noise["l_freq"], filt_params_noise["h_freq"])
        self.filt_params_signal = filt_params_signal
        self.filt_params_noise = filt_params_noise
        # check if boolean
        if not isinstance(sort_by_spectral_ratio, bool):
            raise ValueError("sort_by_spectral_ratio must be boolean")
        self.sort_by_spectral_ratio = sort_by_spectral_ratio
        # Default FFT length is the integer sampling frequency.
        if n_fft is None:
            self.n_fft = int(self.info["sfreq"])
        else:
            self.n_fft = int(n_fft)
        # check if boolean
        if not isinstance(return_filtered, bool):
            raise ValueError("return_filtered must be boolean")
        self.return_filtered = return_filtered
        self.reg = reg
        self.n_components = n_components
        self.rank = rank
        self.cov_method_params = cov_method_params

    def _check_X(self, X):
        """Check input data.

        Raises if ``X`` is not a 2D or 3D ndarray, or if its channel axis
        (second to last) does not have ``info['nchan']`` entries.
        """
        _validate_type(X, np.ndarray, "X")
        _check_option("X.ndim", X.ndim, (2, 3))
        n_chan = X.shape[-2]
        if n_chan != self.info["nchan"]:
            raise ValueError(
                "Info must match the input data. "
                f"Found {n_chan} channels but expected {self.info['nchan']}."
            )

    def fit(self, X, y=None):
        """Estimate the SSD decomposition on raw or epoched data.

        Parameters
        ----------
        X : array, shape ([n_epochs, ]n_channels, n_times)
            The input data from which to estimate the SSD. Either 2D array
            obtained from continuous data or 3D array obtained from epoched
            data.
        y : None
            Ignored; exists for compatibility with scikit-learn pipelines.

        Returns
        -------
        self : instance of SSD
            Returns the modified instance.
        """
        self._check_X(X)
        X_aux = X[..., self.picks_, :]

        X_signal = filter_data(X_aux, self.info["sfreq"], **self.filt_params_signal)
        X_noise = filter_data(X_aux, self.info["sfreq"], **self.filt_params_noise)
        # The noise estimate is the wide band minus the band of interest.
        X_noise -= X_signal
        if X.ndim == 3:
            # Concatenate epochs along time: (n_channels, n_epochs * n_times).
            X_signal = np.hstack(X_signal)
            X_noise = np.hstack(X_noise)

        # prevent rank change when computing cov with rank='full'
        # NOTE(review): ``self.info`` describes all channels while the data
        # passed here is restricted to ``picks_`` -- confirm this is handled
        # downstream when ``picks`` selects a subset of channels.
        cov_signal = _regularized_covariance(
            X_signal,
            reg=self.reg,
            method_params=self.cov_method_params,
            rank="full",
            info=self.info,
        )
        cov_noise = _regularized_covariance(
            X_noise,
            reg=self.reg,
            method_params=self.cov_method_params,
            rank="full",
            info=self.info,
        )

        # project cov to rank subspace
        cov_signal, cov_noise, rank_proj = _dimensionality_reduction(
            cov_signal, cov_noise, self.info, self.rank
        )

        # Generalized eigenvalue problem: signal covariance vs. noise covariance.
        eigvals_, eigvects_ = eigh(cov_signal, cov_noise)
        # sort in descending order
        ix = np.argsort(eigvals_)[::-1]
        self.eigvals_ = eigvals_[ix]
        # project back to sensor space
        self.filters_ = np.matmul(rank_proj, eigvects_[:, ix])
        self.patterns_ = np.linalg.pinv(self.filters_)

        # We assume that ordering by spectral ratio is more important
        # than the initial ordering. This ordering should be also learned when
        # fitting.
        X_ssd = self.filters_.T @ X[..., self.picks_, :]
        sorter_spec = Ellipsis
        if self.sort_by_spectral_ratio:
            _, sorter_spec = self.get_spectral_ratio(ssd_sources=X_ssd)
        self.sorter_spec = sorter_spec
        logger.info("Done.")
        return self

    def transform(self, X):
        """Estimate epochs sources given the SSD filters.

        Parameters
        ----------
        X : array, shape ([n_epochs, ]n_channels, n_times)
            The input data from which to estimate the SSD. Either 2D array
            obtained from continuous data or 3D array obtained from epoched
            data.

        Returns
        -------
        X_ssd : array, shape ([n_epochs, ]n_components, n_times)
            The processed data.

        Raises
        ------
        RuntimeError
            If the instance has not been fitted yet.
        """
        self._check_X(X)
        # ``filters_`` only exists after ``fit``; use getattr so that an
        # unfitted instance raises the intended RuntimeError.
        if getattr(self, "filters_", None) is None:
            raise RuntimeError("No filters available. Please first call fit")
        # Select the picked channels exactly once. Indexing the already
        # reduced array with ``picks_`` again would select wrong channels
        # (or fail) whenever ``picks_`` is a strict subset of the channels.
        X_aux = X[..., self.picks_, :]
        if self.return_filtered:
            X_aux = filter_data(
                X_aux, self.info["sfreq"], **self.filt_params_signal
            )
        X_ssd = self.filters_.T @ X_aux
        # Apply the ordering learned in ``fit``, then keep the leading components.
        if X.ndim == 2:
            X_ssd = X_ssd[self.sorter_spec][: self.n_components]
        else:
            X_ssd = X_ssd[:, self.sorter_spec, :][:, : self.n_components, :]
        return X_ssd

    def get_spectral_ratio(self, ssd_sources):
        """Get the spectral signal-to-noise ratio for each spatial filter.

        Spectral ratio measure for best n_components selection
        See :footcite:`NikulinEtAl2011`, Eq. (24).

        Parameters
        ----------
        ssd_sources : array
            Data projected to SSD space.

        Returns
        -------
        spec_ratio : array, shape (n_components,)
            Array with the spectral ratio value for each component (one value
            per source in ``ssd_sources``).
        sorter_spec : array, shape (n_components,)
            Array of indices sorting ``spec_ratio`` in descending order.

        References
        ----------
        .. footbibliography::
        """
        psd, freqs = psd_array_welch(
            ssd_sources, sfreq=self.info["sfreq"], n_fft=self.n_fft
        )
        sig_idx = _time_mask(freqs, *self.freqs_signal)
        noise_idx = _time_mask(freqs, *self.freqs_noise)
        # Average the PSD over the frequency bins of each band ...
        mean_sig = psd[..., sig_idx].mean(axis=-1)
        mean_noise = psd[..., noise_idx].mean(axis=-1)
        if psd.ndim == 3:
            # ... and, for epoched data, over epochs as well.
            mean_sig = mean_sig.mean(axis=0)
            mean_noise = mean_noise.mean(axis=0)
        spec_ratio = mean_sig / mean_noise
        sorter_spec = spec_ratio.argsort()[::-1]
        return spec_ratio, sorter_spec

    def inverse_transform(self):
        """Not implemented yet."""
        raise NotImplementedError("inverse_transform is not yet available.")

    def apply(self, X):
        """Remove selected components from the signal.

        This procedure will reconstruct M/EEG signals from which the dynamics
        described by the excluded components is subtracted
        (denoised by low-rank factorization).
        See :footcite:`HaufeEtAl2014b` for more information.

        .. note:: Unlike in other classes with an apply method,
           only NumPy arrays are supported (not instances of MNE objects).

        Parameters
        ----------
        X : array, shape ([n_epochs, ]n_channels, n_times)
            The input data from which to estimate the SSD. Either 2D array
            obtained from continuous data or 3D array obtained from epoched
            data.

        Returns
        -------
        X : array, shape ([n_epochs, ]n_channels, n_times)
            The processed data.
        """
        X_ssd = self.transform(X)
        # Use the patterns of the same components that ``transform`` kept.
        pick_patterns = self.patterns_[self.sorter_spec][: self.n_components].T
        X = pick_patterns @ X_ssd
        return X
|
||
|
||
def _dimensionality_reduction(cov_signal, cov_noise, info, rank):
    """Project the signal and noise covariances onto a common rank subspace.

    The rank of each covariance matrix is obtained through ``compute_rank``
    and the smaller of the two is used. If that rank is lower than the number
    of channels, a projection is built from the leading eigenvectors of the
    signal covariance, each scaled by the inverse square root of its
    eigenvalue. Otherwise the identity is used.

    Parameters
    ----------
    cov_signal : array, shape (n_channels, n_channels)
        Covariance of the signal-band data.
    cov_noise : array, shape (n_channels, n_channels)
        Covariance of the noise-band data.
    info : mne.Info
        Measurement info forwarded to ``Covariance`` and ``compute_rank``.
    rank : None | dict | 'info' | 'full'
        Rank specification forwarded to ``compute_rank``.

    Returns
    -------
    cov_signal : array
        Signal covariance projected with ``rank_proj``.
    cov_noise : array
        Noise covariance projected with ``rank_proj``.
    rank_proj : array
        The projection matrix that was applied to both covariances.
    """
    n_channels = cov_signal.shape[0]

    def _estimate_rank(data):
        # Wrap the bare matrix in a Covariance so compute_rank can consume it.
        cov = Covariance(
            data,
            info.ch_names,
            list(),
            list(),
            0,
            verbose=_verbose_safe_false(),
        )
        scalings = _handle_default("scalings_cov_rank", None)
        ranks = compute_rank(cov, rank, scalings, info)
        return list(ranks.values())[0]

    # find ranks of covariance matrices
    rank_signal = _estimate_rank(cov_signal)
    rank_noise = _estimate_rank(cov_noise)
    n_rank = np.min([rank_signal, rank_noise])  # should be identical

    if n_rank >= n_channels:
        # Full rank: nothing to reduce.
        projector = np.eye(n_channels)
        logger.info("Preserving covariance rank (%i)" % (n_rank,))
    else:
        evals, evecs = eigh(cov_signal)
        # sort in descending order
        order = np.argsort(evals)[::-1]
        evals = evals[order]
        evecs = evecs[:, order]
        # compute rank subspace projection matrix
        scaling = np.eye(n_rank) * (evals[:n_rank] ** -0.5)
        projector = evecs[:, :n_rank] @ scaling
        logger.info(
            "Projecting covariance of %i channels to %i rank subspace"
            % (
                n_channels,
                n_rank,
            )
        )

    # project covariance matrices to rank subspace
    cov_signal = projector.T @ (cov_signal @ projector)
    cov_noise = projector.T @ (cov_noise @ projector)
    return cov_signal, cov_noise, projector