# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from copy import deepcopy

import numpy as np
from scipy.fft import fft, fftfreq, ifft

from .._fiff.pick import _pick_data_channels, pick_info
from ..parallel import parallel_func
from ..utils import _validate_type, legacy, logger, verbose
from .tfr import AverageTFRArray, _ensure_slice, _get_data


def _check_input_st(x_in, n_fft):
    """Zero-pad the last (time) axis of ``x_in`` up to the FFT length.

    Parameters
    ----------
    x_in : ndarray
        Input signal; time is the last axis.
    n_fft : int | None
        Requested FFT length. If None, or if it is not a power of two and is
        shorter than the signal, the next power of two >= the signal length
        is used instead.

    Returns
    -------
    x_in : ndarray
        The input, zero-padded along the last axis to ``n_fft`` samples.
    n_fft : int
        The FFT length actually used.
    zero_pad : int
        Number of zeros appended to the last axis.

    Raises
    ------
    ValueError
        If ``n_fft`` is shorter than the signal and was not replaced by the
        rule above.
    """
    n_times = x_in.shape[-1]

    def _is_power_of_two(n):
        # Note: this also returns True for n <= 0.
        return not (n > 0 and (n & (n - 1)))

    if n_fft is None or (not _is_power_of_two(n_fft) and n_times > n_fft):
        # Compute next power of 2
        n_fft = 2 ** int(np.ceil(np.log2(n_times)))
    elif n_fft < n_times:
        raise ValueError(
            f"n_fft cannot be smaller than signal size. Got {n_fft} < {n_times}."
        )
    if n_times < n_fft:
        logger.info(
            f'The input signal is shorter ({x_in.shape[-1]}) than "n_fft" ({n_fft}). '
            "Applying zero padding."
        )
        zero_pad = n_fft - n_times
        pad_array = np.zeros(x_in.shape[:-1] + (zero_pad,), x_in.dtype)
        x_in = np.concatenate((x_in, pad_array), axis=-1)
    else:
        zero_pad = 0
    return x_in, n_fft, zero_pad


def _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width):
    """Precompute Stockwell Gaussian windows (in the freq domain).

    Parameters
    ----------
    n_samp : int
        Number of samples of the (zero-padded) signal.
    start_f, stop_f : int
        Half-open range ``[start_f, stop_f)`` of FFT bin indices.
    sfreq : float
        The sampling frequency.
    width : float
        Gaussian width factor ``k``; 1 gives the classical S-transform.

    Returns
    -------
    windows : ndarray, shape (stop_f - start_f, n_samp), complex128
        FFT of the sum-normalised Gaussian window for each frequency bin.
    """
    tw = fftfreq(n_samp, 1.0 / sfreq) / n_samp
    tw = np.r_[tw[:1], tw[1:][::-1]]
    k = width  # 1 for classical Stockwell transform
    f_range = np.arange(start_f, stop_f, 1)
    windows = np.empty((len(f_range), len(tw)), dtype=np.complex128)
    for i_f, f in enumerate(f_range):
        if f == 0.0:
            # Bin 0: the Gaussian width would be infinite, use a flat window.
            window = np.ones(len(tw))
        else:
            # The Gaussian only depends on tw**2; its prefactor cancels out in
            # the normalisation below.
            window = (f / (np.sqrt(2.0 * np.pi) * k)) * np.exp(
                -0.5 * (1.0 / k**2.0) * (f**2.0) * tw**2.0
            )
        window /= window.sum()  # normalisation
        windows[i_f] = fft(window)
    return windows


def _st(x, start_f, windows):
    """Compute ST based on Ali Moukadem MATLAB code (used in tests).

    Returns the complex S-transform with shape
    ``x.shape[:-1] + (len(windows), x.shape[-1])``.
    """
    n_samp = x.shape[-1]
    ST = np.empty(x.shape[:-1] + (len(windows), n_samp), dtype=np.complex128)
    # do the work
    Fx = fft(x)
    # Duplicate the spectrum so that XF[..., f : f + n_samp] is the spectrum
    # circularly shifted by f bins, without any wrap-around indexing.
    XF = np.concatenate([Fx, Fx], axis=-1)
    for i_f, window in enumerate(windows):
        f = start_f + i_f
        ST[..., i_f, :] = ifft(XF[..., f : f + n_samp] * window)
    return ST


def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W):
    """Compute epoch-averaged S-transform power (and optionally ITC).

    Parameters
    ----------
    x : ndarray, shape (n_epochs, n_samp)
        Zero-padded data of a single channel.
    start_f : int
        FFT bin index of the first frequency in ``W``.
    compute_itc : bool
        Whether to compute the intertrial coherence.
    zero_pad : int
        Number of padding samples at the end of ``x``; they are excluded
        from the output time axis.
    decim : int | slice
        Decimation applied to the (unpadded) time axis.
    W : ndarray, shape (n_freqs, n_samp)
        Windows from ``_precompute_st_windows``.

    Returns
    -------
    psd : ndarray, shape (n_freqs, n_out)
        Mean over epochs of the squared S-transform magnitude.
    itc : ndarray, shape (n_freqs, n_out) | None
        Intertrial coherence, or None if ``compute_itc`` is False.
    """
    decim = _ensure_slice(decim)
    n_samp = x.shape[-1]
    decim_indices = decim.indices(n_samp - zero_pad)
    n_out = len(range(*decim_indices))
    psd = np.empty((len(W), n_out))
    itc = np.empty_like(psd) if compute_itc else None
    X = fft(x)
    XX = np.concatenate([X, X], axis=-1)
    for i_f, window in enumerate(W):
        f = start_f + i_f
        ST = ifft(XX[:, f : f + n_samp] * window)
        TFR = ST[:, slice(*decim_indices)]
        TFR_abs = np.abs(TFR)
        # Power uses the raw magnitudes, so exactly-zero coefficients
        # contribute zero power (not the 1.0 used as a division guard below).
        psd[i_f] = np.mean(TFR_abs * TFR_abs, axis=0)
        if compute_itc:
            # Guard against division by zero; a zero coefficient stays zero
            # after normalisation because its numerator is zero as well.
            TFR_abs[TFR_abs == 0] = 1.0
            TFR /= TFR_abs
            itc[i_f] = np.abs(np.mean(TFR, axis=0))
    return psd, itc


def _compute_freqs_st(fmin, fmax, n_fft, sfreq):
    """Map ``fmin``/``fmax`` to a range of FFT bin indices.

    Returns
    -------
    start_f : int
        Index of the FFT bin closest to ``fmin``.
    stop_f : int
        Index of the FFT bin closest to ``fmax``. The range is half-open, so
        this bin itself is excluded from ``freqs``.
    freqs : ndarray
        The frequencies of bins ``start_f`` to ``stop_f - 1``.
    """
    freqs = fftfreq(n_fft, 1.0 / sfreq)
    if fmin is None:
        fmin = freqs[freqs > 0][0]
    if fmax is None:
        fmax = freqs.max()
    start_f = np.abs(freqs - fmin).argmin()
    stop_f = np.abs(freqs - fmax).argmin()
    freqs = freqs[start_f:stop_f]
    return start_f, stop_f, freqs


@verbose
def tfr_array_stockwell(
    data,
    sfreq,
    fmin=None,
    fmax=None,
    n_fft=None,
    width=1.0,
    decim=1,
    return_itc=False,
    n_jobs=None,
    *,
    verbose=None,
):
    """Compute power and intertrial coherence using Stockwell (S) transform.

    Same computation as `~mne.time_frequency.tfr_stockwell`, but operates on
    :class:`NumPy arrays <numpy.ndarray>` instead of `~mne.Epochs` objects.

    See :footcite:`Stockwell2007,MoukademEtAl2014,WheatEtAl2010,JonesEtAl2006`
    for more information.

    Parameters
    ----------
    data : ndarray, shape (n_epochs, n_channels, n_times)
        The signal to transform.
    sfreq : float
        The sampling frequency.
    fmin : None, float
        The minimum frequency to include. If None defaults to the minimum fft
        frequency greater than zero.
    fmax : None, float
        The maximum frequency to include. If None defaults to the maximum fft.
    n_fft : int | None
        The length of the windows used for FFT. If None, it defaults to the
        next power of 2 larger than the signal length.
    width : float
        The width of the Gaussian window. If < 1, increased temporal
        resolution, if > 1, increased frequency resolution. Defaults to 1.
        (classical S-Transform).
    %(decim_tfr)s
    return_itc : bool
        Return intertrial coherence (ITC) as well as averaged power.
    %(n_jobs)s
    %(verbose)s

    Returns
    -------
    st_power : ndarray
        The power of the Stockwell transformed data, averaged over epochs.
        The last two dimensions are frequency and time.
    itc : ndarray | None
        The intertrial coherence. None if return_itc is False.
    freqs : ndarray
        The frequencies.

    See Also
    --------
    mne.time_frequency.tfr_stockwell
    mne.time_frequency.tfr_multitaper
    mne.time_frequency.tfr_array_multitaper
    mne.time_frequency.tfr_morlet
    mne.time_frequency.tfr_array_morlet

    References
    ----------
    .. footbibliography::
    """
    _validate_type(data, np.ndarray, "data")
    if data.ndim != 3:
        raise ValueError(
            "data must be 3D with shape (n_epochs, n_channels, n_times), "
            f"got {data.shape}"
        )
    decim = _ensure_slice(decim)
    # Output time length is determined on the unpadded data.
    _, n_channels, n_out = data[..., decim].shape
    data, n_fft_, zero_pad = _check_input_st(data, n_fft)
    start_f, stop_f, freqs = _compute_freqs_st(fmin, fmax, n_fft_, sfreq)

    W = _precompute_st_windows(data.shape[-1], start_f, stop_f, sfreq, width)
    n_freq = stop_f - start_f
    psd = np.empty((n_channels, n_freq, n_out))
    itc = np.empty((n_channels, n_freq, n_out)) if return_itc else None

    # Channels are processed independently (optionally in parallel).
    parallel, my_st, n_jobs = parallel_func(_st_power_itc, n_jobs, verbose=verbose)
    tfrs = parallel(
        my_st(data[:, c, :], start_f, return_itc, zero_pad, decim, W)
        for c in range(n_channels)
    )
    for c, (this_psd, this_itc) in enumerate(tfrs):
        psd[c] = this_psd
        if this_itc is not None:
            itc[c] = this_itc

    return psd, itc, freqs


@legacy(alt='.compute_tfr(method="stockwell", freqs="auto")')
@verbose
def tfr_stockwell(
    inst,
    fmin=None,
    fmax=None,
    n_fft=None,
    width=1.0,
    decim=1,
    return_itc=False,
    n_jobs=None,
    verbose=None,
):
    """Compute Time-Frequency Representation (TFR) using Stockwell Transform.

    Same computation as `~mne.time_frequency.tfr_array_stockwell`, but operates on
    `~mne.Epochs` objects instead of :class:`NumPy arrays <numpy.ndarray>`.

    See :footcite:`Stockwell2007,MoukademEtAl2014,WheatEtAl2010,JonesEtAl2006`
    for more information.

    Parameters
    ----------
    inst : Epochs | Evoked
        The epochs or evoked object.
    fmin : None, float
        The minimum frequency to include. If None defaults to the minimum fft
        frequency greater than zero.
    fmax : None, float
        The maximum frequency to include. If None defaults to the maximum fft.
    n_fft : int | None
        The length of the windows used for FFT. If None, it defaults to the
        next power of 2 larger than the signal length.
    width : float
        The width of the Gaussian window. If < 1, increased temporal
        resolution, if > 1, increased frequency resolution. Defaults to 1.
        (classical S-Transform).
    decim : int
        The decimation factor on the time axis. To reduce memory usage.
    return_itc : bool
        Return intertrial coherence (ITC) as well as averaged power.
    n_jobs : int
        The number of jobs to run in parallel (over channels).
    %(verbose)s

    Returns
    -------
    power : AverageTFR
        The averaged power.
    itc : AverageTFR
        The intertrial coherence. Only returned if return_itc is True.

    See Also
    --------
    mne.time_frequency.tfr_array_stockwell
    mne.time_frequency.tfr_multitaper
    mne.time_frequency.tfr_array_multitaper
    mne.time_frequency.tfr_morlet
    mne.time_frequency.tfr_array_morlet

    Notes
    -----
    .. versionadded:: 0.9.0

    References
    ----------
    .. footbibliography::
    """
    # verbose dec is used b/c subfunctions are verbose
    data = _get_data(inst, return_itc)
    # Only data channels are transformed; info is reduced to match.
    picks = _pick_data_channels(inst.info)
    info = pick_info(inst.info, picks)
    data = data[:, picks, :]
    decim = _ensure_slice(decim)
    power, itc, freqs = tfr_array_stockwell(
        data,
        sfreq=info["sfreq"],
        fmin=fmin,
        fmax=fmax,
        n_fft=n_fft,
        width=width,
        decim=decim,
        return_itc=return_itc,
        n_jobs=n_jobs,
    )
    times = inst.times[decim].copy()
    nave = len(data)
    out = AverageTFRArray(
        info=info,
        data=power,
        times=times,
        freqs=freqs,
        nave=nave,
        method="stockwell-power",
    )
    if return_itc:
        # The ITC container gets its own copies so the two objects do not
        # share info/times/freqs.
        out = (
            out,
            AverageTFRArray(
                info=deepcopy(info),
                data=itc,
                times=times.copy(),
                freqs=freqs.copy(),
                nave=nave,
                method="stockwell-itc",
            ),
        )
    return out