170 lines
5.1 KiB
Python
170 lines
5.1 KiB
Python
# Authors: The MNE-Python contributors.
|
|
# License: BSD-3-Clause
|
|
# Copyright the MNE-Python contributors.
|
|
|
|
import numpy as np
|
|
|
|
from ..time_frequency.tfr import _compute_tfr
|
|
from ..utils import _check_option, fill_doc, verbose
|
|
from .base import BaseEstimator
|
|
from .mixin import TransformerMixin
|
|
|
|
|
|
@fill_doc
class TimeFrequency(TransformerMixin, BaseEstimator):
    """Time frequency transformer.

    Time-frequency transform of time series along the last axis.

    Parameters
    ----------
    freqs : array-like of float, shape (n_freqs,)
        The frequencies.
    sfreq : float | int, default 1.0
        Sampling frequency of the data.
    method : 'multitaper' | 'morlet', default 'morlet'
        The time-frequency method. 'morlet' convolves a Morlet wavelet.
        'multitaper' uses Morlet wavelets windowed with multiple DPSS
        multitapers.
    n_cycles : float | array of float, default 7.0
        Number of cycles in the Morlet wavelet. Fixed number
        or one per frequency.
    time_bandwidth : float, default None
        If None and method=multitaper, will be set to 4.0 (3 tapers).
        Time x (Full) Bandwidth product. Only applies if
        method == 'multitaper'. The number of good tapers (low-bias) is
        chosen automatically based on this to equal floor(time_bandwidth - 1).
    use_fft : bool, default True
        Use the FFT for convolutions or not.
    decim : int | slice, default 1
        To reduce memory usage, decimation factor after time-frequency
        decomposition.
        If `int`, returns tfr[..., ::decim].
        If `slice`, returns tfr[..., decim].

        .. note:: Decimation may create aliasing artifacts, yet decimation
                  is done after the convolutions.

    output : str, default 'complex'
        * 'complex' : single trial complex.
        * 'power' : single trial power.
        * 'phase' : single trial phase.

        Only these single-trial outputs are accepted. The value is checked
        at construction and again on every call to ``transform``.
    %(n_jobs)s
        The number of epochs to process at the same time. The parallelization
        is implemented across channels.
    %(verbose)s

    See Also
    --------
    mne.time_frequency.tfr_morlet
    mne.time_frequency.tfr_multitaper
    """

    # Outputs accepted by this transformer. Averaged metrics (e.g. ITC) are
    # deliberately excluded.
    _SINGLE_TRIAL_OUTPUTS = ("complex", "power", "phase")

    @verbose
    def __init__(
        self,
        freqs,
        sfreq=1.0,
        method="morlet",
        n_cycles=7.0,
        time_bandwidth=None,
        use_fft=True,
        decim=1,
        output="complex",
        n_jobs=1,
        verbose=None,
    ):
        """Init TimeFrequency transformer."""
        # Check that output is a single-trial quantity and not an average
        # metric (e.g. ITC). Kept here so that an unsupported value is still
        # reported at construction time.
        output = _check_option("output", output, list(self._SINGLE_TRIAL_OUTPUTS))

        self.freqs = freqs
        self.sfreq = sfreq
        self.method = method
        self.n_cycles = n_cycles
        self.time_bandwidth = time_bandwidth
        self.use_fft = use_fft
        self.decim = decim
        self.output = output
        self.n_jobs = n_jobs
        self.verbose = verbose

    def fit_transform(self, X, y=None):
        """Time-frequency transform of time series along the last axis.

        Equivalent to calling ``fit(X, y)`` followed by ``transform(X)``.

        Parameters
        ----------
        X : array, shape (n_samples, n_channels, n_times)
            The training data samples. The channel dimension can be zero- or
            1-dimensional.
        y : None
            For scikit-learn compatibility purposes.

        Returns
        -------
        Xt : array, shape (n_samples, n_channels, n_freqs, n_times)
            The time-frequency transform of the data, where n_channels can be
            zero- or 1-dimensional.
        """
        return self.fit(X, y).transform(X)

    def fit(self, X, y=None):  # noqa: D401
        """Do nothing (for scikit-learn compatibility purposes).

        Parameters
        ----------
        X : array, shape (n_samples, n_channels, n_times)
            The training data.
        y : array | None
            The target values.

        Returns
        -------
        self : object
            Return self.
        """
        return self

    def transform(self, X):
        """Time-frequency transform of time series along the last axis.

        Parameters
        ----------
        X : array, shape (n_samples, n_channels, n_times)
            The training data samples. The channel dimension can be zero- or
            1-dimensional. Any array-like is accepted and converted with
            ``numpy.asarray``.

        Returns
        -------
        Xt : array, shape (n_samples, n_channels, n_freqs, n_times)
            The time-frequency transform of the data, where n_channels can be
            zero- or 1-dimensional.

        Notes
        -----
        The ``output`` option is re-checked here with ``_check_option``
        because parameters can be modified after construction (for example
        through ``set_params``), which does not go through ``__init__``.
        """
        # Re-validate: the check in __init__ is bypassed when the attribute is
        # changed after construction. The return value is intentionally
        # unused; self.output is forwarded as-is below.
        _check_option("output", self.output, list(self._SINGLE_TRIAL_OUTPUTS))

        # Accept array-likes such as nested lists; ndarrays pass through
        # unchanged.
        X = np.asarray(X)

        # Ensure 3-dimensional X: without a channel axis, insert a singleton
        # one so that the data is (n_samples, 1, n_times).
        add_channel_axis = not X.shape[1:-1]
        if add_channel_axis:
            X = X[:, np.newaxis, :]

        # Compute time-frequency
        Xt = _compute_tfr(
            X,
            freqs=self.freqs,
            sfreq=self.sfreq,
            method=self.method,
            n_cycles=self.n_cycles,
            zero_mean=True,
            time_bandwidth=self.time_bandwidth,
            use_fft=self.use_fft,
            decim=self.decim,
            output=self.output,
            n_jobs=self.n_jobs,
            verbose=self.verbose,
        )

        # Back to original shape: drop the singleton channel axis added above.
        if add_channel_axis:
            Xt = Xt[:, 0, :]

        return Xt
|