# Authors: The MNE-Python contributors. # License: BSD-3-Clause # Copyright the MNE-Python contributors. import numpy as np from ..time_frequency.tfr import _compute_tfr from ..utils import _check_option, fill_doc, verbose from .base import BaseEstimator from .mixin import TransformerMixin @fill_doc class TimeFrequency(TransformerMixin, BaseEstimator): """Time frequency transformer. Time-frequency transform of times series along the last axis. Parameters ---------- freqs : array-like of float, shape (n_freqs,) The frequencies. sfreq : float | int, default 1.0 Sampling frequency of the data. method : 'multitaper' | 'morlet', default 'morlet' The time-frequency method. 'morlet' convolves a Morlet wavelet. 'multitaper' uses Morlet wavelets windowed with multiple DPSS multitapers. n_cycles : float | array of float, default 7.0 Number of cycles in the Morlet wavelet. Fixed number or one per frequency. time_bandwidth : float, default None If None and method=multitaper, will be set to 4.0 (3 tapers). Time x (Full) Bandwidth product. Only applies if method == 'multitaper'. The number of good tapers (low-bias) is chosen automatically based on this to equal floor(time_bandwidth - 1). use_fft : bool, default True Use the FFT for convolutions or not. decim : int | slice, default 1 To reduce memory usage, decimation factor after time-frequency decomposition. If `int`, returns tfr[..., ::decim]. If `slice`, returns tfr[..., decim]. .. note:: Decimation may create aliasing artifacts, yet decimation is done after the convolutions. output : str, default 'complex' * 'complex' : single trial complex. * 'power' : single trial power. * 'phase' : single trial phase. %(n_jobs)s The number of epochs to process at the same time. The parallelization is implemented across channels. %(verbose)s See Also -------- mne.time_frequency.tfr_morlet mne.time_frequency.tfr_multitaper """ @verbose def __init__( self, freqs, sfreq=1.0, method="morlet", n_cycles=7.0, time_bandwidth=None, use_fft=True, decim=1, output="complex", n_jobs=1, verbose=None, ): """Init TimeFrequency transformer.""" # Check non-average output output = _check_option("output", output, ["complex", "power", "phase"]) self.freqs = freqs self.sfreq = sfreq self.method = method self.n_cycles = n_cycles self.time_bandwidth = time_bandwidth self.use_fft = use_fft self.decim = decim # Check that output is not an average metric (e.g. ITC) self.output = output self.n_jobs = n_jobs self.verbose = verbose def fit_transform(self, X, y=None): """Time-frequency transform of times series along the last axis. Parameters ---------- X : array, shape (n_samples, n_channels, n_times) The training data samples. The channel dimension can be zero- or 1-dimensional. y : None For scikit-learn compatibility purposes. Returns ------- Xt : array, shape (n_samples, n_channels, n_freqs, n_times) The time-frequency transform of the data, where n_channels can be zero- or 1-dimensional. """ return self.fit(X, y).transform(X) def fit(self, X, y=None): # noqa: D401 """Do nothing (for scikit-learn compatibility purposes). Parameters ---------- X : array, shape (n_samples, n_channels, n_times) The training data. y : array | None The target values. Returns ------- self : object Return self. """ return self def transform(self, X): """Time-frequency transform of times series along the last axis. Parameters ---------- X : array, shape (n_samples, n_channels, n_times) The training data samples. The channel dimension can be zero- or 1-dimensional. Returns ------- Xt : array, shape (n_samples, n_channels, n_freqs, n_times) The time-frequency transform of the data, where n_channels can be zero- or 1-dimensional. """ # Ensure 3-dimensional X shape = X.shape[1:-1] if not shape: X = X[:, np.newaxis, :] # Compute time-frequency Xt = _compute_tfr( X, freqs=self.freqs, sfreq=self.sfreq, method=self.method, n_cycles=self.n_cycles, zero_mean=True, time_bandwidth=self.time_bandwidth, use_fft=self.use_fft, decim=self.decim, output=self.output, n_jobs=self.n_jobs, verbose=self.verbose, ) # Back to original shape if not shape: Xt = Xt[:, 0, :] return Xt