"""A module which implements the time-frequency estimation. Morlet code inspired by Matlab code from Sheraz Khan & Brainstorm & SPM """ # Authors: The MNE-Python contributors. # License: BSD-3-Clause # Copyright the MNE-Python contributors. import inspect from copy import deepcopy from functools import partial import matplotlib.pyplot as plt import numpy as np from scipy.fft import fft, ifft from scipy.signal import argrelmax from .._fiff.meas_info import ContainsMixin, Info from .._fiff.pick import _picks_to_idx, pick_info from ..baseline import _check_baseline, rescale from ..channels.channels import UpdateChannelsMixin from ..channels.layout import _find_topomap_coords, _merge_ch_data, _pair_grad_sensors from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT from ..filter import next_fast_len from ..parallel import parallel_func from ..utils import ( ExtendedTimeMixin, GetEpochsMixin, SizeMixin, _build_data_frame, _check_combine, _check_event_id, _check_fname, _check_method_kwargs, _check_option, _check_pandas_index_arguments, _check_pandas_installed, _check_time_format, _convert_times, _ensure_events, _freq_mask, _import_h5io_funcs, _is_numeric, _pl, _prepare_read_metadata, _prepare_write_metadata, _time_mask, _validate_type, check_fname, copy_doc, copy_function_doc_to_method_doc, fill_doc, legacy, logger, object_diff, repr_html, sizeof_fmt, verbose, warn, ) from ..utils.spectrum import _get_instance_type_string from ..viz.topo import _imshow_tfr, _imshow_tfr_unified, _plot_topo from ..viz.topomap import ( _add_colorbar, _get_pos_outlines, _set_contour_locator, plot_tfr_topomap, plot_topomap, ) from ..viz.utils import ( _make_combine_callable, _prepare_joint_axes, _set_title_multiple_electrodes, _setup_cmap, _setup_vmin_vmax, add_background_image, figure_nobar, plt_show, ) from .multitaper import dpss_windows, tfr_array_multitaper from .spectrum import EpochsSpectrum @fill_doc def morlet(sfreq, freqs, n_cycles=7.0, sigma=None, zero_mean=False): """Compute Morlet wavelets for the given frequency range. Parameters ---------- sfreq : float The sampling Frequency. freqs : float | array-like, shape (n_freqs,) Frequencies to compute Morlet wavelets for. n_cycles : float | array-like, shape (n_freqs,) Number of cycles. Can be a fixed number (float) or one per frequency (array-like). sigma : float, default None It controls the width of the wavelet ie its temporal resolution. If sigma is None the temporal resolution is adapted with the frequency like for all wavelet transform. The higher the frequency the shorter is the wavelet. If sigma is fixed the temporal resolution is fixed like for the short time Fourier transform and the number of oscillations increases with the frequency. zero_mean : bool, default False Make sure the wavelet has a mean of zero. Returns ------- Ws : list of ndarray | ndarray The wavelets time series. If ``freqs`` was a float, a single ndarray is returned instead of a list of ndarray. See Also -------- mne.time_frequency.fwhm Notes ----- %(morlet_reference)s %(fwhm_morlet_notes)s References ---------- .. footbibliography:: Examples -------- Let's show a simple example of the relationship between ``n_cycles`` and the FWHM using :func:`mne.time_frequency.fwhm`: .. plot:: import numpy as np import matplotlib.pyplot as plt from mne.time_frequency import morlet, fwhm sfreq, freq, n_cycles = 1000., 10, 7 # i.e., 700 ms this_fwhm = fwhm(freq, n_cycles) wavelet = morlet(sfreq=sfreq, freqs=freq, n_cycles=n_cycles) M, w = len(wavelet), n_cycles # convert to SciPy convention s = w * sfreq / (2 * freq * np.pi) # from SciPy docs _, ax = plt.subplots(layout="constrained") colors = dict(real="#66CCEE", imag="#EE6677") t = np.arange(-M // 2 + 1, M // 2 + 1) / sfreq for kind in ('real', 'imag'): ax.plot( t, getattr(wavelet, kind), label=kind, color=colors[kind], ) ax.plot(t, np.abs(wavelet), label=f'abs', color='k', lw=1., zorder=6) half_max = np.max(np.abs(wavelet)) / 2. ax.plot([-this_fwhm / 2., this_fwhm / 2.], [half_max, half_max], color='k', linestyle='-', label='FWHM', zorder=6) ax.legend(loc='upper right') ax.set(xlabel='Time (s)', ylabel='Amplitude') """ # noqa: E501 Ws = list() n_cycles = np.array(n_cycles, float).ravel() freqs = np.array(freqs, float) if np.any(freqs <= 0): raise ValueError("all frequencies in 'freqs' must be greater than 0.") if (n_cycles.size != 1) and (n_cycles.size != len(freqs)): raise ValueError("n_cycles should be fixed or defined for each frequency.") _check_option("freqs.ndim", freqs.ndim, [0, 1]) singleton = freqs.ndim == 0 if singleton: freqs = freqs[np.newaxis] for k, f in enumerate(freqs): if len(n_cycles) != 1: this_n_cycles = n_cycles[k] else: this_n_cycles = n_cycles[0] # sigma_t is the stddev of gaussian window in the time domain; can be # scale-dependent or fixed across freqs if sigma is None: sigma_t = this_n_cycles / (2.0 * np.pi * f) else: sigma_t = this_n_cycles / (2.0 * np.pi * sigma) # time vector. We go 5 standard deviations out to make sure we're # *very* close to zero at the ends. We also make sure that there's a # sample at exactly t=0 t = np.arange(0.0, 5.0 * sigma_t, 1.0 / sfreq) t = np.r_[-t[::-1], t[1:]] oscillation = np.exp(2.0 * 1j * np.pi * f * t) if zero_mean: # this offset is equivalent to the κ_σ term in Wikipedia's # equations, and satisfies the "admissibility criterion" for CWTs real_offset = np.exp(-2 * (np.pi * f * sigma_t) ** 2) oscillation -= real_offset gaussian_envelope = np.exp(-(t**2) / (2.0 * sigma_t**2)) W = oscillation * gaussian_envelope # the scaling factor here is proportional to what is used in # Tallon-Baudry 1997: (sigma_t*sqrt(pi))^(-1/2). It yields a wavelet # with norm sqrt(2) for the full wavelet / norm 1 for the real part W /= np.sqrt(0.5) * np.linalg.norm(W.ravel()) Ws.append(W) if singleton: Ws = Ws[0] return Ws def fwhm(freq, n_cycles): """Compute the full-width half maximum of a Morlet wavelet. Uses the formula from :footcite:t:`Cohen2019`. Parameters ---------- freq : float The oscillation frequency of the wavelet. n_cycles : float The duration of the wavelet, expressed as the number of oscillation cycles. Returns ------- fwhm : float The full-width half maximum of the wavelet. Notes ----- .. versionadded:: 1.3 References ---------- .. footbibliography:: """ return n_cycles * np.sqrt(2 * np.log(2)) / (np.pi * freq) def _make_dpss( sfreq, freqs, n_cycles=7.0, time_bandwidth=4.0, zero_mean=False, return_weights=False, ): """Compute DPSS tapers for the given frequency range. Parameters ---------- sfreq : float The sampling frequency. freqs : ndarray, shape (n_freqs,) The frequencies in Hz. n_cycles : float | ndarray, shape (n_freqs,), default 7. The number of cycles globally or for each frequency. time_bandwidth : float, default 4.0 Time x Bandwidth product. The number of good tapers (low-bias) is chosen automatically based on this to equal floor(time_bandwidth - 1). Default is 4.0, giving 3 good tapers. zero_mean : bool | None, , default False Make sure the wavelet has a mean of zero. return_weights : bool Whether to return the concentration weights. Returns ------- Ws : list of array The wavelets time series. """ Ws = list() freqs = np.array(freqs) if np.any(freqs <= 0): raise ValueError("all frequencies in 'freqs' must be greater than 0.") if time_bandwidth < 2.0: raise ValueError("time_bandwidth should be >= 2.0 for good tapers") n_taps = int(np.floor(time_bandwidth - 1)) n_cycles = np.atleast_1d(n_cycles) if n_cycles.size != 1 and n_cycles.size != len(freqs): raise ValueError("n_cycles should be fixed or defined for each frequency.") for m in range(n_taps): Wm = list() for k, f in enumerate(freqs): if len(n_cycles) != 1: this_n_cycles = n_cycles[k] else: this_n_cycles = n_cycles[0] t_win = this_n_cycles / float(f) t = np.arange(0.0, t_win, 1.0 / sfreq) # Making sure wavelets are centered before tapering oscillation = np.exp(2.0 * 1j * np.pi * f * (t - t_win / 2.0)) # Get dpss tapers tapers, conc = dpss_windows( t.shape[0], time_bandwidth / 2.0, n_taps, sym=False ) Wk = oscillation * tapers[m] if zero_mean: # to make it zero mean real_offset = Wk.mean() Wk -= real_offset Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel()) Wm.append(Wk) Ws.append(Wm) if return_weights: return Ws, conc return Ws # Low level convolution def _get_nfft(wavelets, X, use_fft=True, check=True): n_times = X.shape[-1] max_size = max(w.size for w in wavelets) if max_size > n_times: msg = ( f"At least one of the wavelets ({max_size}) is longer than the " f"signal ({n_times}). Consider using a longer signal or " "shorter wavelets." ) if check: if use_fft: warn(msg, UserWarning) else: raise ValueError(msg) nfft = n_times + max_size - 1 nfft = next_fast_len(nfft) # 2 ** int(np.ceil(np.log2(nfft))) return nfft def _cwt_gen(X, Ws, *, fsize=0, mode="same", decim=1, use_fft=True): """Compute cwt with fft based convolutions or temporal convolutions. Parameters ---------- X : array of shape (n_signals, n_times) The data. Ws : list of array Wavelets time series. fsize : int FFT length. mode : {'full', 'valid', 'same'} See numpy.convolve. decim : int | slice, default 1 To reduce memory usage, decimation factor after time-frequency decomposition. If `int`, returns tfr[..., ::decim]. If `slice`, returns tfr[..., decim]. .. note:: Decimation may create aliasing artifacts. use_fft : bool, default True Use the FFT for convolutions or not. Returns ------- out : array, shape (n_signals, n_freqs, n_time_decim) The time-frequency transform of the signals. """ _check_option("mode", mode, ["same", "valid", "full"]) decim = _ensure_slice(decim) X = np.asarray(X) # Precompute wavelets for given frequency range to save time _, n_times = X.shape n_times_out = X[:, decim].shape[1] n_freqs = len(Ws) # precompute FFTs of Ws if use_fft: fft_Ws = np.empty((n_freqs, fsize), dtype=np.complex128) for i, W in enumerate(Ws): fft_Ws[i] = fft(W, fsize) # Make generator looping across signals tfr = np.zeros((n_freqs, n_times_out), dtype=np.complex128) for x in X: if use_fft: fft_x = fft(x, fsize) # Loop across wavelets for ii, W in enumerate(Ws): if use_fft: ret = ifft(fft_x * fft_Ws[ii])[: n_times + W.size - 1] else: # Work around multarray.correlate->OpenBLAS bug on ppc64le # ret = np.correlate(x, W, mode=mode) ret = np.convolve(x, W.real, mode=mode) + 1j * np.convolve( x, W.imag, mode=mode ) # Center and decimate decomposition if mode == "valid": sz = int(abs(W.size - n_times)) + 1 offset = (n_times - sz) // 2 this_slice = slice(offset // decim.step, (offset + sz) // decim.step) if use_fft: ret = _centered(ret, sz) tfr[ii, this_slice] = ret[decim] elif mode == "full" and not use_fft: start = (W.size - 1) // 2 end = len(ret) - (W.size // 2) ret = ret[start:end] tfr[ii, :] = ret[decim] else: if use_fft: ret = _centered(ret, n_times) tfr[ii, :] = ret[decim] yield tfr # Loop of convolution: single trial def _compute_tfr( epoch_data, freqs, sfreq=1.0, method="morlet", n_cycles=7.0, zero_mean=None, time_bandwidth=None, use_fft=True, decim=1, output="complex", n_jobs=None, *, verbose=None, ): """Compute time-frequency transforms. Parameters ---------- epoch_data : array of shape (n_epochs, n_channels, n_times) The epochs.default ``'complex'`` freqs : array-like of floats, shape (n_freqs) The frequencies. sfreq : float | int, default 1.0 Sampling frequency of the data. method : 'multitaper' | 'morlet', default 'morlet' The time-frequency method. 'morlet' convolves a Morlet wavelet. 'multitaper' uses complex exponentials windowed with multiple DPSS tapers. n_cycles : float | array of float, default 7.0 Number of cycles in the wavelet. Fixed number or one per frequency. zero_mean : bool | None, default None None means True for method='multitaper' and False for method='morlet'. If True, make sure the wavelets have a mean of zero. time_bandwidth : float, default None If None and method=multitaper, will be set to 4.0 (3 tapers). Time x (Full) Bandwidth product. Only applies if method == 'multitaper'. The number of good tapers (low-bias) is chosen automatically based on this to equal floor(time_bandwidth - 1). use_fft : bool, default True Use the FFT for convolutions or not. decim : int | slice, default 1 To reduce memory usage, decimation factor after time-frequency decomposition. If `int`, returns tfr[..., ::decim]. If `slice`, returns tfr[..., decim]. .. note:: Decimation may create aliasing artifacts, yet decimation is done after the convolutions. output : str * 'complex' (default) : single trial complex. * 'power' : single trial power. * 'phase' : single trial phase. * 'avg_power' : average of single trial power. * 'itc' : inter-trial coherence. * 'avg_power_itc' : average of single trial power and inter-trial coherence across trials. %(n_jobs)s The number of epochs to process at the same time. The parallelization is implemented across channels. %(verbose)s Returns ------- out : array Time frequency transform of epoch_data. If output is in ['complex', 'phase', 'power'], then shape of ``out`` is ``(n_epochs, n_chans, n_freqs, n_times)``, else it is ``(n_chans, n_freqs, n_times)``. However, using multitaper method and output ``'complex'`` or ``'phase'`` results in shape of ``out`` being ``(n_epochs, n_chans, n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the real values in the ``output`` contain average power' and the imaginary values contain the ITC: ``out = avg_power + i * itc``. """ # Check data epoch_data = np.asarray(epoch_data) if epoch_data.ndim != 3: raise ValueError( "epoch_data must be of shape (n_epochs, n_chans, " f"n_times), got {epoch_data.shape}" ) # Check params freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim = _check_tfr_param( freqs, sfreq, method, zero_mean, n_cycles, time_bandwidth, use_fft, decim, output, ) decim = _ensure_slice(decim) if (freqs > sfreq / 2.0).any(): raise ValueError( "Cannot compute freq above Nyquist freq of the data " f"({sfreq / 2.0:0.1f} Hz), got {freqs.max():0.1f} Hz" ) # We decimate *after* decomposition, so we need to create our kernels # for the original sfreq if method == "morlet": W = morlet(sfreq, freqs, n_cycles=n_cycles, zero_mean=zero_mean) Ws = [W] # to have same dimensionality as the 'multitaper' case elif method == "multitaper": Ws = _make_dpss( sfreq, freqs, n_cycles=n_cycles, time_bandwidth=time_bandwidth, zero_mean=zero_mean, ) # Check wavelets if len(Ws[0][0]) > epoch_data.shape[2]: raise ValueError( "At least one of the wavelets is longer than the " "signal. Use a longer signal or shorter wavelets." ) # Initialize output n_freqs = len(freqs) n_tapers = len(Ws) n_epochs, n_chans, n_times = epoch_data[:, :, decim].shape if output in ("power", "phase", "avg_power", "itc"): dtype = np.float64 elif output in ("complex", "avg_power_itc"): # avg_power_itc is stored as power + 1i * itc to keep a # simple dimensionality dtype = np.complex128 if ("avg_" in output) or ("itc" in output): out = np.empty((n_chans, n_freqs, n_times), dtype) elif output in ["complex", "phase"] and method == "multitaper": out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype) else: out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype) # Parallel computation all_Ws = sum([list(W) for W in Ws], list()) _get_nfft(all_Ws, epoch_data, use_fft) parallel, my_cwt, n_jobs = parallel_func(_time_frequency_loop, n_jobs) # Parallelization is applied across channels. tfrs = parallel( my_cwt(channel, Ws, output, use_fft, "same", decim, method) for channel in epoch_data.transpose(1, 0, 2) ) # FIXME: to avoid overheads we should use np.array_split() for channel_idx, tfr in enumerate(tfrs): out[channel_idx] = tfr if ("avg_" not in output) and ("itc" not in output): # This is to enforce that the first dimension is for epochs if output in ["complex", "phase"] and method == "multitaper": out = out.transpose(2, 0, 1, 3, 4) else: out = out.transpose(1, 0, 2, 3) return out def _check_tfr_param( freqs, sfreq, method, zero_mean, n_cycles, time_bandwidth, use_fft, decim, output ): """Aux. function to _compute_tfr to check the params validity.""" # Check freqs if not isinstance(freqs, (list, np.ndarray)): raise ValueError(f"freqs must be an array-like, got {type(freqs)} instead.") freqs = np.asarray(freqs, dtype=float) if freqs.ndim != 1: raise ValueError( f"freqs must be of shape (n_freqs,), got {np.array(freqs.shape)} " "instead." ) # Check sfreq if not isinstance(sfreq, (float, int)): raise ValueError(f"sfreq must be a float or an int, got {type(sfreq)} instead.") sfreq = float(sfreq) # Default zero_mean = True if multitaper else False zero_mean = method == "multitaper" if zero_mean is None else zero_mean if not isinstance(zero_mean, bool): raise ValueError( f"zero_mean should be of type bool, got {type(zero_mean)}. instead" ) freqs = np.asarray(freqs) # Check n_cycles if isinstance(n_cycles, (int, float)): n_cycles = float(n_cycles) elif isinstance(n_cycles, (list, np.ndarray)): n_cycles = np.array(n_cycles) if len(n_cycles) != len(freqs): raise ValueError( "n_cycles must be a float or an array of length " f"{len(freqs)} frequencies, got {len(n_cycles)} cycles instead." ) else: raise ValueError( f"n_cycles must be a float or an array, got {type(n_cycles)} instead." ) # Check time_bandwidth if (method == "morlet") and (time_bandwidth is not None): raise ValueError('time_bandwidth only applies to "multitaper" method.') elif method == "multitaper": time_bandwidth = 4.0 if time_bandwidth is None else float(time_bandwidth) # Check use_fft if not isinstance(use_fft, bool): raise ValueError(f"use_fft must be a boolean, got {type(use_fft)} instead.") # Check decim if isinstance(decim, int): decim = slice(None, None, decim) if not isinstance(decim, slice): raise ValueError( f"decim must be an integer or a slice, got {type(decim)} instead." ) # Check output _check_option( "output", output, ["complex", "power", "phase", "avg_power_itc", "avg_power", "itc"], ) _check_option("method", method, ["multitaper", "morlet"]) return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): """Aux. function to _compute_tfr. Loops time-frequency transform across wavelets and epochs. Parameters ---------- X : array, shape (n_epochs, n_times) The epochs data of a single channel. Ws : list, shape (n_tapers, n_wavelets, n_times) The wavelets. output : str * 'complex' : single trial complex. * 'power' : single trial power. * 'phase' : single trial phase. * 'avg_power' : average of single trial power. * 'itc' : inter-trial coherence. * 'avg_power_itc' : average of single trial power and inter-trial coherence across trials. use_fft : bool Use the FFT for convolutions or not. mode : {'full', 'valid', 'same'} See numpy.convolve. decim : slice The decimation slice: e.g. power[:, decim] method : str | None Used only for multitapering to create tapers dimension in the output if ``output in ['complex', 'phase']``. """ # Set output type dtype = np.float64 if output in ["complex", "avg_power_itc"]: dtype = np.complex128 # Init outputs decim = _ensure_slice(decim) n_tapers = len(Ws) n_epochs, n_times = X[:, decim].shape n_freqs = len(Ws[0]) if ("avg_" in output) or ("itc" in output): tfrs = np.zeros((n_freqs, n_times), dtype=dtype) elif output in ["complex", "phase"] and method == "multitaper": tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), dtype=dtype) else: tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype) # Loops across tapers. for taper_idx, W in enumerate(Ws): # No need to check here, it's done earlier (outside parallel part) nfft = _get_nfft(W, X, use_fft, check=False) coefs = _cwt_gen(X, W, fsize=nfft, mode=mode, decim=decim, use_fft=use_fft) # Inter-trial phase locking is apparently computed per taper... if "itc" in output: plf = np.zeros((n_freqs, n_times), dtype=np.complex128) # Loop across epochs for epoch_idx, tfr in enumerate(coefs): # Transform complex values if output in ["power", "avg_power"]: tfr = (tfr * tfr.conj()).real # power elif output == "phase": tfr = np.angle(tfr) elif output == "avg_power_itc": tfr_abs = np.abs(tfr) plf += tfr / tfr_abs # phase tfr = tfr_abs**2 # power elif output == "itc": plf += tfr / np.abs(tfr) # phase continue # not need to stack anything else than plf # Stack or add if ("avg_" in output) or ("itc" in output): tfrs += tfr elif output in ["complex", "phase"] and method == "multitaper": tfrs[taper_idx, epoch_idx] += tfr else: tfrs[epoch_idx] += tfr # Compute inter trial coherence if output == "avg_power_itc": tfrs += 1j * np.abs(plf) elif output == "itc": tfrs += np.abs(plf) # Normalization of average metrics if ("avg_" in output) or ("itc" in output): tfrs /= n_epochs # Normalization by number of taper if n_tapers > 1 and output not in ["complex", "phase"]: tfrs /= n_tapers return tfrs @fill_doc def cwt(X, Ws, use_fft=True, mode="same", decim=1): """Compute time-frequency decomposition with continuous wavelet transform. Parameters ---------- X : array, shape (n_signals, n_times) The signals. Ws : list of array Wavelets time series. use_fft : bool Use FFT for convolutions. Defaults to True. mode : 'same' | 'valid' | 'full' Convention for convolution. 'full' is currently not implemented with ``use_fft=False``. Defaults to ``'same'``. %(decim_tfr)s Returns ------- tfr : array, shape (n_signals, n_freqs, n_times) The time-frequency decompositions. See Also -------- mne.time_frequency.tfr_morlet : Compute time-frequency decomposition with Morlet wavelets. """ nfft = _get_nfft(Ws, X, use_fft) return _cwt_array(X, Ws, nfft, mode, decim, use_fft) def _cwt_array(X, Ws, nfft, mode, decim, use_fft): decim = _ensure_slice(decim) coefs = _cwt_gen(X, Ws, fsize=nfft, mode=mode, decim=decim, use_fft=use_fft) n_signals, n_times = X[:, decim].shape tfrs = np.empty((n_signals, len(Ws), n_times), dtype=np.complex128) for k, tfr in enumerate(coefs): tfrs[k] = tfr return tfrs def _tfr_aux( method, inst, freqs, decim, return_itc, picks, average, output, **tfr_params ): from ..epochs import BaseEpochs kwargs = dict( method=method, freqs=freqs, picks=picks, decim=decim, output=output, **tfr_params, ) if isinstance(inst, BaseEpochs): kwargs.update(average=average, return_itc=return_itc) elif average: logger.info("inst is Evoked, setting `average=False`") average = False if average and output == "complex": raise ValueError('output must be "power" if average=True') if not average and return_itc: raise ValueError("Inter-trial coherence is not supported with average=False") return inst.compute_tfr(**kwargs) @legacy(alt='.compute_tfr(method="morlet")') @verbose def tfr_morlet( inst, freqs, n_cycles, use_fft=False, return_itc=True, decim=1, n_jobs=None, picks=None, zero_mean=True, average=True, output="power", verbose=None, ): """Compute Time-Frequency Representation (TFR) using Morlet wavelets. Same computation as `~mne.time_frequency.tfr_array_morlet`, but operates on `~mne.Epochs` or `~mne.Evoked` objects instead of :class:`NumPy arrays `. Parameters ---------- inst : Epochs | Evoked The epochs or evoked object. %(freqs_tfr_array)s %(n_cycles_tfr)s use_fft : bool, default False The fft based convolution or not. return_itc : bool, default True Return inter-trial coherence (ITC) as well as averaged power. Must be ``False`` for evoked data. %(decim_tfr)s %(n_jobs)s picks : array-like of int | None, default None The indices of the channels to decompose. If None, all available good data channels are decomposed. zero_mean : bool, default True Make sure the wavelet has a mean of zero. .. versionadded:: 0.13.0 %(average_tfr)s output : str Can be ``"power"`` (default) or ``"complex"``. If ``"complex"``, then ``average`` must be ``False``. .. versionadded:: 0.15.0 %(verbose)s Returns ------- power : AverageTFR | EpochsTFR The averaged or single-trial power. itc : AverageTFR | EpochsTFR The inter-trial coherence (ITC). Only returned if return_itc is True. See Also -------- mne.time_frequency.tfr_array_morlet mne.time_frequency.tfr_multitaper mne.time_frequency.tfr_array_multitaper mne.time_frequency.tfr_stockwell mne.time_frequency.tfr_array_stockwell Notes ----- %(morlet_reference)s %(temporal_window_tfr_intro)s %(temporal_window_tfr_morlet_notes)s See :func:`mne.time_frequency.morlet` for more information about the Morlet wavelet. References ---------- .. footbibliography:: """ tfr_params = dict( n_cycles=n_cycles, n_jobs=n_jobs, use_fft=use_fft, zero_mean=zero_mean, output=output, ) return _tfr_aux( "morlet", inst, freqs, decim, return_itc, picks, average, **tfr_params ) @verbose def tfr_array_morlet( data, sfreq, freqs, n_cycles=7.0, zero_mean=True, use_fft=True, decim=1, output="complex", n_jobs=None, *, verbose=None, ): """Compute Time-Frequency Representation (TFR) using Morlet wavelets. Same computation as `~mne.time_frequency.tfr_morlet`, but operates on :class:`NumPy arrays ` instead of `~mne.Epochs` objects. Parameters ---------- data : array of shape (n_epochs, n_channels, n_times) The epochs. sfreq : float | int Sampling frequency of the data. %(freqs_tfr_array)s %(n_cycles_tfr)s zero_mean : bool | None If True, make sure the wavelets have a mean of zero. default False. .. versionchanged:: 1.8 The default will change from ``zero_mean=False`` in 1.6 to ``True`` in 1.8. use_fft : bool Use the FFT for convolutions or not. default True. %(decim_tfr)s output : str, default ``'complex'`` * ``'complex'`` : single trial complex. * ``'power'`` : single trial power. * ``'phase'`` : single trial phase. * ``'avg_power'`` : average of single trial power. * ``'itc'`` : inter-trial coherence. * ``'avg_power_itc'`` : average of single trial power and inter-trial coherence across trials. %(n_jobs)s The number of epochs to process at the same time. The parallelization is implemented across channels. Default 1. %(verbose)s Returns ------- out : array Time frequency transform of ``data``. - if ``output in ('complex', 'phase', 'power')``, array of shape ``(n_epochs, n_chans, n_freqs, n_times)`` - else, array of shape ``(n_chans, n_freqs, n_times)`` If ``output`` is ``'avg_power_itc'``, the real values in ``out`` contain the average power and the imaginary values contain the ITC: :math:`out = power_{avg} + i * itc`. See Also -------- mne.time_frequency.tfr_morlet mne.time_frequency.tfr_multitaper mne.time_frequency.tfr_array_multitaper mne.time_frequency.tfr_stockwell mne.time_frequency.tfr_array_stockwell Notes ----- %(morlet_reference)s %(temporal_window_tfr_intro)s %(temporal_window_tfr_morlet_notes)s .. versionadded:: 0.14.0 References ---------- .. footbibliography:: """ return _compute_tfr( epoch_data=data, freqs=freqs, sfreq=sfreq, method="morlet", n_cycles=n_cycles, zero_mean=zero_mean, time_bandwidth=None, use_fft=use_fft, decim=decim, output=output, n_jobs=n_jobs, verbose=verbose, ) @legacy(alt='.compute_tfr(method="multitaper")') @verbose def tfr_multitaper( inst, freqs, n_cycles, time_bandwidth=4.0, use_fft=True, return_itc=True, decim=1, n_jobs=None, picks=None, average=True, *, verbose=None, ): """Compute Time-Frequency Representation (TFR) using DPSS tapers. Same computation as :func:`~mne.time_frequency.tfr_array_multitaper`, but operates on :class:`~mne.Epochs` or :class:`~mne.Evoked` objects instead of :class:`NumPy arrays `. Parameters ---------- inst : Epochs | Evoked The epochs or evoked object. %(freqs_tfr_array)s %(n_cycles_tfr)s %(time_bandwidth_tfr)s use_fft : bool, default True The fft based convolution or not. return_itc : bool, default True Return inter-trial coherence (ITC) as well as averaged (or single-trial) power. %(decim_tfr)s %(n_jobs)s %(picks_good_data)s %(average_tfr)s %(verbose)s Returns ------- power : AverageTFR | EpochsTFR The averaged or single-trial power. itc : AverageTFR | EpochsTFR The inter-trial coherence (ITC). Only returned if return_itc is True. See Also -------- mne.time_frequency.tfr_array_multitaper mne.time_frequency.tfr_stockwell mne.time_frequency.tfr_array_stockwell mne.time_frequency.tfr_morlet mne.time_frequency.tfr_array_morlet Notes ----- %(temporal_window_tfr_intro)s %(temporal_window_tfr_multitaper_notes)s %(time_bandwidth_tfr_notes)s .. versionadded:: 0.9.0 """ from ..epochs import EpochsArray from ..evoked import Evoked tfr_params = dict( n_cycles=n_cycles, n_jobs=n_jobs, use_fft=use_fft, zero_mean=True, time_bandwidth=time_bandwidth, ) if isinstance(inst, Evoked) and not average: # convert AverageTFR to EpochsTFR for backwards compatibility inst = EpochsArray(inst.data[np.newaxis], inst.info, tmin=inst.tmin, proj=False) return _tfr_aux( method="multitaper", inst=inst, freqs=freqs, decim=decim, return_itc=return_itc, picks=picks, average=average, output="power", **tfr_params, ) # TFR(s) class @fill_doc class BaseTFR(ContainsMixin, UpdateChannelsMixin, SizeMixin, ExtendedTimeMixin): """Base class for RawTFR, EpochsTFR, and AverageTFR (for type checking only). .. note:: This class should not be instantiated directly; it is provided in the public API only for type-checking purposes (e.g., ``isinstance(my_obj, BaseTFR)``). To create TFR objects, use the ``.compute_tfr()`` methods on :class:`~mne.io.Raw`, :class:`~mne.Epochs`, or :class:`~mne.Evoked`, or use the constructors listed below under "See Also". Parameters ---------- inst : instance of Raw, Epochs, or Evoked The data from which to compute the time-frequency representation. %(method_tfr)s %(freqs_tfr)s %(tmin_tmax_psd)s %(picks_good_data_noref)s %(proj_psd)s %(decim_tfr)s %(n_jobs)s %(reject_by_annotation_tfr)s %(verbose)s %(method_kw_tfr)s See Also -------- mne.time_frequency.RawTFR mne.time_frequency.RawTFRArray mne.time_frequency.EpochsTFR mne.time_frequency.EpochsTFRArray mne.time_frequency.AverageTFR mne.time_frequency.AverageTFRArray """ def __init__( self, inst, method, freqs, tmin, tmax, picks, proj, *, decim, n_jobs, reject_by_annotation=None, verbose=None, **method_kw, ): from ..epochs import BaseEpochs from ._stockwell import tfr_array_stockwell # triage reading from file if isinstance(inst, dict): self.__setstate__(inst) return if method is None or freqs is None: problem = [ f"{k}=None" for k, v in dict(method=method, freqs=freqs).items() if v is None ] # TODO when py3.11 is min version, replace if/elif/else block with # classname = inspect.currentframe().f_back.f_code.co_qualname.split(".")[0] _varnames = inspect.currentframe().f_back.f_code.co_varnames if "BaseRaw" in _varnames: classname = "RawTFR" elif "Evoked" in _varnames: classname = "AverageTFR" else: assert "BaseEpochs" in _varnames and "Evoked" not in _varnames classname = "EpochsTFR" # end TODO raise ValueError( f'{classname} got unsupported parameter value{_pl(problem)} ' f'{" and ".join(problem)}.' ) # shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release) if method == "morlet": method_kw.setdefault("zero_mean", True) # check method valid_methods = ["morlet", "multitaper"] if isinstance(inst, BaseEpochs): valid_methods.append("stockwell") method = _check_option("method", method, valid_methods) # for stockwell, `tmin, tmax` already added to `method_kw` by calling method, # and `freqs` vector has been pre-computed if method != "stockwell": method_kw.update(freqs=freqs) # ↓↓↓ if constructor called directly, prevents key error method_kw.setdefault("output", "power") self._freqs = np.asarray(freqs, dtype=np.float64) del freqs # check validity of kwargs manually to save compute time if any are invalid tfr_funcs = dict( morlet=tfr_array_morlet, multitaper=tfr_array_multitaper, stockwell=tfr_array_stockwell, ) _check_method_kwargs(tfr_funcs[method], method_kw, msg=f'TFR method "{method}"') self._tfr_func = partial(tfr_funcs[method], **method_kw) # apply proj if desired if proj: inst = inst.copy().apply_proj() self.inst = inst # prep picks and add the info object. bads and non-data channels are dropped by # _picks_to_idx() so we update the info accordingly: self._picks = _picks_to_idx(inst.info, picks, "data", with_ref_meg=False) self.info = pick_info(inst.info, sel=self._picks, copy=True) # assign some attributes self._method = method self._inst_type = type(inst) self._baseline = None self.preload = True # needed for __getitem__, never False for TFRs # self._dims may also get updated by child classes self._dims = ["channel", "freq", "time"] self._needs_taper_dim = method == "multitaper" and method_kw["output"] in ( "complex", "phase", ) if self._needs_taper_dim: self._dims.insert(1, "taper") self._dims = tuple(self._dims) # get the instance data. time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq) get_instance_data_kw = dict(time_mask=time_mask) if reject_by_annotation is not None: get_instance_data_kw.update(reject_by_annotation=reject_by_annotation) data = self._get_instance_data(**get_instance_data_kw) # compute the TFR self._decim = _ensure_slice(decim) self._raw_times = inst.times[time_mask] self._compute_tfr(data, n_jobs, verbose) self._update_epoch_attributes() # "apply" decim to the rest of the object (data is decimated in _compute_tfr) with self.info._unlock(): self.info["sfreq"] /= self._decim.step _decim_times = inst.times[self._decim] _decim_time_mask = _time_mask(_decim_times, tmin, tmax, sfreq=self.sfreq) self._raw_times = _decim_times[_decim_time_mask].copy() self._set_times(self._raw_times) self._decim = 1 # record data type (for repr and html_repr). ITC handled in the calling method. if method == "stockwell": self._data_type = "Power Estimates" else: data_types = dict( power="Power Estimates", avg_power="Average Power Estimates", avg_power_itc="Average Power Estimates", phase="Phase", complex="Complex Amplitude", ) self._data_type = data_types[method_kw["output"]] # check for correct shape and bad values. `tfr_array_stockwell` doesn't take kw # `output` so it may be missing here, so use `.get()` negative_ok = method_kw.get("output", "") in ("complex", "phase") # if method_kw.get("output", None) in ("phase", "complex"): # raise RuntimeError self._check_values(negative_ok=negative_ok) # we don't need these anymore, and they make save/load harder del self._picks del self._tfr_func del self._needs_taper_dim del self._shape # calculated from self._data henceforth del self.inst # save memory def __abs__(self): """Return the absolute value.""" tfr = self.copy() tfr.data = np.abs(tfr.data) return tfr @fill_doc def __add__(self, other): """Add two TFR instances. %(__add__tfr)s """ self._check_compatibility(other) out = self.copy() out.data += other.data return out @fill_doc def __iadd__(self, other): """Add a TFR instance to another, in-place. %(__iadd__tfr)s """ self._check_compatibility(other) self.data += other.data return self @fill_doc def __sub__(self, other): """Subtract two TFR instances. %(__sub__tfr)s """ self._check_compatibility(other) out = self.copy() out.data -= other.data return out @fill_doc def __isub__(self, other): """Subtract a TFR instance from another, in-place. %(__isub__tfr)s """ self._check_compatibility(other) self.data -= other.data return self @fill_doc def __mul__(self, num): """Multiply a TFR instance by a scalar. %(__mul__tfr)s """ out = self.copy() out.data *= num return out @fill_doc def __imul__(self, num): """Multiply a TFR instance by a scalar, in-place. %(__imul__tfr)s """ self.data *= num return self @fill_doc def __truediv__(self, num): """Divide a TFR instance by a scalar. %(__truediv__tfr)s """ out = self.copy() out.data /= num return out @fill_doc def __itruediv__(self, num): """Divide a TFR instance by a scalar, in-place. %(__itruediv__tfr)s """ self.data /= num return self def __eq__(self, other): """Test equivalence of two TFR instances.""" return object_diff(vars(self), vars(other)) == "" def __getstate__(self): """Prepare object for serialization.""" return dict( method=self.method, data=self._data, sfreq=self.sfreq, dims=self._dims, freqs=self.freqs, times=self.times, inst_type_str=_get_instance_type_string(self), data_type=self._data_type, info=self.info, baseline=self._baseline, decim=self._decim, ) def __setstate__(self, state): """Unpack from serialized format.""" from ..epochs import Epochs from ..evoked import Evoked from ..io import Raw defaults = dict( method="unknown", dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :], baseline=None, decim=1, data_type="TFR", inst_type_str="Unknown", ) defaults.update(**state) self._method = defaults["method"] self._data = defaults["data"] self._freqs = np.asarray(defaults["freqs"], dtype=np.float64) self._dims = defaults["dims"] self._raw_times = np.asarray(defaults["times"], dtype=np.float64) self._baseline = defaults["baseline"] self.info = Info(**defaults["info"]) self._data_type = defaults["data_type"] self._decim = defaults["decim"] self.preload = True self._set_times(self._raw_times) # Handle instance type. Prior to gh-11282, Raw was not a possibility so if # `inst_type_str` is missing it must be Epochs or Evoked unknown_class = Epochs if "epoch" in self._dims else Evoked inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Unknown=unknown_class) self._inst_type = inst_types[defaults["inst_type_str"]] # sanity check data/freqs/times/info agreement self._check_state() def __repr__(self): """Build string representation of the TFR object.""" inst_type_str = _get_instance_type_string(self) nave = f" (nave={self.nave})" if hasattr(self, "nave") else "" # shape & dimension names dims = " × ".join( [f"{size} {dim}s" for size, dim in zip(self.shape, self._dims)] ) freq_range = f"{self.freqs[0]:0.1f} - {self.freqs[-1]:0.1f} Hz" time_range = f"{self.times[0]:0.2f} - {self.times[-1]:0.2f} s" return ( f"<{self._data_type} from {inst_type_str}{nave}, " f"{self.method} method | {dims}, {freq_range}, {time_range}, " f"{sizeof_fmt(self._size)}>" ) @repr_html def _repr_html_(self, caption=None): """Build HTML representation of the TFR object.""" from ..html_templates import _get_html_template inst_type_str = _get_instance_type_string(self) nave = getattr(self, "nave", 0) t = _get_html_template("repr", "tfr.html.jinja") t = t.render(tfr=self, inst_type=inst_type_str, nave=nave, caption=caption) return t def _check_compatibility(self, other): """Check compatibility of two TFR instances, in preparation for arithmetic.""" operation = inspect.currentframe().f_back.f_code.co_name.strip("_") if operation.startswith("i"): operation = operation[1:] msg = f"Cannot {operation} the two TFR instances: {{}} do not match{{}}." extra = "" if not isinstance(other, type(self)): problem = "types" extra = f" (self is {type(self)}, other is {type(other)})" elif not self.times.shape == other.times.shape or np.any( self.times != other.times ): problem = "times" elif not self.freqs.shape == other.freqs.shape or np.any( self.freqs != other.freqs ): problem = "freqs" else: # should be OK return raise RuntimeError(msg.format(problem, extra)) def _check_state(self): """Check data/freqs/times/info agreement during __setstate__.""" msg = "{} axis of data ({}) doesn't match {} attribute ({})" n_chan_info = len(self.info["chs"]) n_chan = self._data.shape[self._dims.index("channel")] n_freq = self._data.shape[self._dims.index("freq")] n_time = self._data.shape[self._dims.index("time")] if n_chan_info != n_chan: msg = msg.format("Channel", n_chan, "info", n_chan_info) elif n_freq != len(self.freqs): msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size) elif n_time != len(self.times): msg = msg.format("Time", n_time, "times", self.times.size) else: return raise ValueError(msg) def _check_values(self, negative_ok=False): """Check TFR results for correct shape and bad values.""" assert len(self._dims) == self._data.ndim assert self._data.shape == self._shape # Check for implausible power values: take min() across all but the channel axis # TODO: should this be more fine-grained (report "chan X in epoch Y")? ch_dim = self._dims.index("channel") dims = np.arange(self._data.ndim).tolist() dims.pop(ch_dim) negative_values = self._data.min(axis=tuple(dims)) < 0 if negative_values.any() and not negative_ok: chs = np.array(self.ch_names)[negative_values].tolist() s = _pl(negative_values.sum()) warn( f"Negative value in time-frequency decomposition for channel{s} " f'{", ".join(chs)}', UserWarning, ) def _compute_tfr(self, data, n_jobs, verbose): result = self._tfr_func( data, self.sfreq, decim=self._decim, n_jobs=n_jobs, verbose=verbose, ) # assign ._data and maybe ._itc # tfr_array_stockwell always returns ITC (sometimes it's None) if self.method == "stockwell": self._data, self._itc, freqs = result assert np.array_equal(self._freqs, freqs) elif self._tfr_func.keywords.get("output", "").endswith("_itc"): self._data, self._itc = result.real, result.imag else: self._data = result # remove fake "epoch" dimension if self.method != "stockwell" and _get_instance_type_string(self) != "Epochs": self._data = np.squeeze(self._data, axis=0) # this is *expected* shape, it gets asserted later in _check_values() # (and then deleted afterwards) expected_shape = [ len(self.ch_names), len(self.freqs), len(self._raw_times[self._decim]), # don't use self.times, not set yet ] # deal with the "taper" dimension if self._needs_taper_dim: expected_shape.insert(1, self._data.shape[1]) self._shape = tuple(expected_shape) @verbose def _onselect( self, eclick, erelease, picks=None, exclude="bads", combine="mean", baseline=None, mode=None, cmap=None, source_plot_joint=False, topomap_args=None, verbose=None, ): """Respond to rectangle selector in TFR image plots with a topomap plot.""" if abs(eclick.x - erelease.x) < 0.1 or abs(eclick.y - erelease.y) < 0.1: return t_range = (min(eclick.xdata, erelease.xdata), max(eclick.xdata, erelease.xdata)) f_range = (min(eclick.ydata, erelease.ydata), max(eclick.ydata, erelease.ydata)) # snap to nearest measurement point t_idx = np.abs(self.times - np.atleast_2d(t_range).T).argmin(axis=1) f_idx = np.abs(self.freqs - np.atleast_2d(f_range).T).argmin(axis=1) tmin, tmax = self.times[t_idx] fmin, fmax = self.freqs[f_idx] # immutable → mutable default if topomap_args is None: topomap_args = dict() topomap_args.setdefault("cmap", cmap) topomap_args.setdefault("vlim", (None, None)) # figure out which channel types we're dealing with types = list() if "eeg" in self: types.append("eeg") if "mag" in self: types.append("mag") if "grad" in self: grad_picks = _pair_grad_sensors( self.info, topomap_coords=False, raise_error=False ) if len(grad_picks) > 1: types.append("grad") elif len(types) == 0: logger.info( "Need at least 2 gradiometer pairs to plot a gradiometer topomap." ) return # Don't draw a figure for nothing. fig = figure_nobar() t_range = f"{tmin:.3f}" if tmin == tmax else f"{tmin:.3f} - {tmax:.3f}" f_range = f"{fmin:.2f}" if fmin == fmax else f"{fmin:.2f} - {fmax:.2f}" fig.suptitle(f"{t_range} s,\n{f_range} Hz") if source_plot_joint: ax = fig.add_subplot() data, times, freqs = self.get_data( picks=picks, exclude=exclude, return_times=True, return_freqs=True ) # merge grads before baselining (makes ERDs visible) ch_types = np.array(self.get_channel_types(unique=True)) ch_type = ch_types.item() # will error if there are more than one data, pos = _merge_if_grads( data=data, info=self.info, ch_type=ch_type, sphere=topomap_args.get("sphere"), combine=combine, ) # baseline and crop data, *_ = _prep_data_for_plot( data, times, freqs, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, baseline=baseline, mode=mode, verbose=verbose, ) # average over times and freqs data = data.mean((-2, -1)) im, _ = plot_topomap(data, pos, axes=ax, show=False, **topomap_args) _add_colorbar(ax, im, topomap_args["cmap"], title="AU") plt_show(fig=fig) else: for idx, ch_type in enumerate(types): ax = fig.add_subplot(1, len(types), idx + 1) plot_tfr_topomap( self, ch_type=ch_type, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, baseline=baseline, mode=mode, axes=ax, **topomap_args, ) ax.set_title(ch_type) def _update_epoch_attributes(self): # overwritten in EpochsTFR; adds things needed for to_data_frame and __getitem__ pass @property def _detrend_picks(self): """Provide compatibility with __iter__.""" return list() @property def baseline(self): """Start and end of the baseline period (in seconds).""" return self._baseline @property def ch_names(self): """The channel names.""" return self.info["ch_names"] @property def data(self): """The time-frequency-resolved power estimates.""" return self._data @data.setter def data(self, data): self._data = data @property def freqs(self): """The frequencies at which power estimates were computed.""" return self._freqs @property def method(self): """The method used to compute the time-frequency power estimates.""" return self._method @property def sfreq(self): """Sampling frequency of the data.""" return self.info["sfreq"] @property def shape(self): """Data shape.""" return self._data.shape @property def times(self): """The time points present in the data (in seconds).""" return self._times_readonly @fill_doc def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): """Crop data to a given time interval in place. Parameters ---------- %(tmin_tmax_psd)s fmin : float | None Lowest frequency of selection in Hz. .. versionadded:: 0.18.0 fmax : float | None Highest frequency of selection in Hz. .. versionadded:: 0.18.0 %(include_tmax)s Returns ------- %(inst_tfr)s The modified instance. """ super().crop(tmin=tmin, tmax=tmax, include_tmax=include_tmax) if fmin is not None or fmax is not None: freq_mask = _freq_mask( self.freqs, sfreq=self.info["sfreq"], fmin=fmin, fmax=fmax ) else: freq_mask = slice(None) self._freqs = self.freqs[freq_mask] # Deal with broadcasting (boolean arrays do not broadcast, but indices # do, so we need to convert freq_mask to make use of broadcasting) if isinstance(freq_mask, np.ndarray): freq_mask = np.where(freq_mask)[0] self._data = self._data[..., freq_mask, :] return self def copy(self): """Return copy of the TFR instance. Returns ------- %(inst_tfr)s A copy of the object. """ return deepcopy(self) @verbose def apply_baseline(self, baseline, mode="mean", verbose=None): """Baseline correct the data. Parameters ---------- %(baseline_rescale)s How baseline is computed is determined by the ``mode`` parameter. mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio' Perform baseline correction by - subtracting the mean of baseline values ('mean') - dividing by the mean of baseline values ('ratio') - dividing by the mean of baseline values and taking the log ('logratio') - subtracting the mean of baseline values followed by dividing by the mean of baseline values ('percent') - subtracting the mean of baseline values and dividing by the standard deviation of baseline values ('zscore') - dividing by the mean of baseline values, taking the log, and dividing by the standard deviation of log baseline values ('zlogratio') %(verbose)s Returns ------- %(inst_tfr)s The modified instance. """ self._baseline = _check_baseline(baseline, times=self.times, sfreq=self.sfreq) rescale(self.data, self.times, self.baseline, mode, copy=False, verbose=verbose) return self @fill_doc def get_data( self, picks=None, exclude="bads", fmin=None, fmax=None, tmin=None, tmax=None, return_times=False, return_freqs=False, ): """Get time-frequency data in NumPy array format. Parameters ---------- %(picks_good_data_noref)s %(exclude_spectrum_get_data)s %(fmin_fmax_tfr)s %(tmin_tmax_psd)s return_times : bool Whether to return the time values for the requested time range. Default is ``False``. return_freqs : bool Whether to return the frequency bin values for the requested frequency range. Default is ``False``. Returns ------- data : array The requested data in a NumPy array. times : array The time values for the requested data range. Only returned if ``return_times`` is ``True``. freqs : array The frequency values for the requested data range. Only returned if ``return_freqs`` is ``True``. Notes ----- Returns a copy of the underlying data (not a view). """ tmin = self.times[0] if tmin is None else tmin tmax = self.times[-1] if tmax is None else tmax fmin = 0 if fmin is None else fmin fmax = np.inf if fmax is None else fmax picks = _picks_to_idx( self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False ) fmin_idx = np.searchsorted(self.freqs, fmin) fmax_idx = np.searchsorted(self.freqs, fmax, side="right") tmin_idx = np.searchsorted(self.times, tmin) tmax_idx = np.searchsorted(self.times, tmax, side="right") freq_picks = np.arange(fmin_idx, fmax_idx) time_picks = np.arange(tmin_idx, tmax_idx) freq_axis = self._dims.index("freq") time_axis = self._dims.index("time") chan_axis = self._dims.index("channel") # normally there's a risk of np.take reducing array dimension if there # were only one channel or frequency selected, but `_picks_to_idx` # and np.arange both always return arrays, so we're safe; the result # will always have the same `ndim` as it started with. data = ( self._data.take(picks, chan_axis) .take(freq_picks, freq_axis) .take(time_picks, time_axis) ) out = [data] if return_times: times = self._raw_times[tmin_idx:tmax_idx] out.append(times) if return_freqs: freqs = self._freqs[fmin_idx:fmax_idx] out.append(freqs) if not return_times and not return_freqs: return out[0] return tuple(out) @verbose def plot( self, picks=None, *, exclude=(), tmin=None, tmax=None, fmin=0.0, fmax=np.inf, baseline=None, mode="mean", dB=False, combine=None, layout=None, # TODO deprecate? not used in orig implementation either yscale="auto", vmin=None, vmax=None, vlim=(None, None), cnorm=None, cmap=None, colorbar=True, title=None, # don't deprecate this one; has (useful) option title="auto" mask=None, mask_style=None, mask_cmap="Greys", mask_alpha=0.1, axes=None, show=True, verbose=None, ): """Plot TFRs as two-dimensional time-frequency images. Parameters ---------- %(picks_good_data)s %(exclude_spectrum_plot)s %(tmin_tmax_psd)s %(fmin_fmax_tfr)s %(baseline_rescale)s How baseline is computed is determined by the ``mode`` parameter. %(mode_tfr_plot)s %(dB_spectrum_plot)s %(combine_tfr_plot)s .. versionchanged:: 1.3 Added support for ``callable``. %(layout_spectrum_plot_topo)s %(yscale_tfr_plot)s .. versionadded:: 0.14.0 %(vmin_vmax_tfr_plot)s %(vlim_tfr_plot)s %(cnorm)s .. versionadded:: 0.24 %(cmap_topomap)s %(colorbar)s %(title_tfr_plot)s %(mask_tfr_plot)s .. versionadded:: 0.16.0 %(mask_style_tfr_plot)s .. versionadded:: 0.17 %(mask_cmap_tfr_plot)s .. versionadded:: 0.17 %(mask_alpha_tfr_plot)s .. versionadded:: 0.16.0 %(axes_tfr_plot)s %(show)s %(verbose)s Returns ------- figs : list of instances of matplotlib.figure.Figure A list of figures containing the time-frequency power. """ # deprecations vlim = _warn_deprecated_vmin_vmax(vlim, vmin, vmax) # the rectangle selector plots topomaps, which needs all channels uncombined, # so we keep a reference to that state here, and (because the topomap plotting # function wants an AverageTFR) update it with `comment` and `nave` values in # case we started out with a singleton EpochsTFR or RawTFR initial_state = self.__getstate__() initial_state.setdefault("comment", "") initial_state.setdefault("nave", 1) # `_picks_to_idx` also gets done inside `get_data()`` below, but we do it here # because we need the indices later idx_picks = _picks_to_idx( self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False ) pick_names = np.array(self.ch_names)[idx_picks].tolist() # for titles ch_types = self.get_channel_types(idx_picks) # get data arrays data, times, freqs = self.get_data( picks=idx_picks, exclude=(), return_times=True, return_freqs=True ) # pass tmin/tmax here ↓↓↓, not here ↑↑↑; we want to crop *after* baselining data, times, freqs = _prep_data_for_plot( data, times, freqs, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, baseline=baseline, mode=mode, dB=dB, verbose=verbose, ) # shape ch_axis = self._dims.index("channel") freq_axis = self._dims.index("freq") time_axis = self._dims.index("time") want_shape = list(self.shape) want_shape[ch_axis] = len(idx_picks) if combine is None else 1 want_shape[freq_axis] = len(freqs) # in case there was fmin/fmax cropping want_shape[time_axis] = len(times) # in case there was tmin/tmax cropping want_shape = tuple(want_shape) # combine combine_was_none = combine is None combine = _make_combine_callable( combine, axis=ch_axis, valid=("mean", "rms"), keepdims=True ) try: data = combine(data) # no need to copy; get_data() never returns a view except Exception as e: msg = ( "Something went wrong with the callable passed to 'combine'; see " "traceback." ) raise ValueError(msg) from e # call succeeded, check type and shape mismatch = False if not isinstance(data, np.ndarray): mismatch = "type" extra = "" elif data.shape not in (want_shape, want_shape[1:]): mismatch = "shape" extra = f" of shape {data.shape}" if mismatch: raise RuntimeError( f"Wrong {mismatch} yielded by callable passed to 'combine'. Make sure " "your function takes a single argument (an array of shape " "(n_channels, n_freqs, n_times)) and returns an array of shape " f"(n_freqs, n_times); yours yielded: {type(data)}{extra}." ) # restore singleton collapsed axis (removed by user-provided callable): # (n_freqs, n_times) → (1, n_freqs, n_times) if data.shape == (len(freqs), len(times)): data = data[np.newaxis] assert data.shape == want_shape # cmap handling. power may be negative depending on baseline strategy so set # `norm` empirically — but only if user didn't set limits explicitly. norm = False if vlim == (None, None) else data.min() >= 0.0 vmin, vmax = _setup_vmin_vmax(data, *vlim, norm=norm) cmap = _setup_cmap(cmap, norm=norm) # prepare figure(s) if axes is None: figs = [plt.figure(layout="constrained") for _ in range(data.shape[0])] axes = [fig.add_subplot() for fig in figs] elif isinstance(axes, plt.Axes): figs = [axes.get_figure()] axes = [axes] elif isinstance(axes, np.ndarray): # allow plotting into a grid of axes figs = [ax.get_figure() for ax in axes.flat] elif hasattr(axes, "__iter__") and len(axes): figs = [ax.get_figure() for ax in axes] else: raise ValueError( f"axes must be None, Axes, or list/array of Axes, got {type(axes)}" ) if len(axes) != data.shape[0]: raise RuntimeError( f"Mismatch between picked channels ({data.shape[0]}) and axes " f"({len(axes)}); there must be one axes for each picked channel." ) # check if we're being called from within plot_joint(). If so, get the # `topomap_args` from the calling context and pass it to the onselect handler. # (we need 2 `f_back` here because of the verbose decorator) calling_frame = inspect.currentframe().f_back.f_back source_plot_joint = calling_frame.f_code.co_name == "plot_joint" topomap_args = ( dict() if not source_plot_joint else calling_frame.f_locals.get("topomap_args", dict()) ) # plot for ix, _fig in enumerate(figs): # restrict the onselect instance to the channel type of the picks used in # the image plot uniq_types = np.unique(ch_types) ch_type = None if len(uniq_types) > 1 else uniq_types.item() this_tfr = AverageTFR(inst=initial_state).pick(ch_type, verbose=verbose) _onselect_callback = partial( this_tfr._onselect, picks=None, # already restricted the picks in `this_tfr` exclude=(), baseline=baseline, mode=mode, cmap=cmap, source_plot_joint=source_plot_joint, topomap_args=topomap_args, ) # draw the image plot _imshow_tfr( ax=axes[ix], tfr=data[[ix]], ch_idx=0, tmin=times[0], tmax=times[-1], vmin=vmin, vmax=vmax, onselect=_onselect_callback, ylim=None, freq=freqs, x_label="Time (s)", y_label="Frequency (Hz)", colorbar=colorbar, cmap=cmap, yscale=yscale, mask=mask, mask_style=mask_style, mask_cmap=mask_cmap, mask_alpha=mask_alpha, cnorm=cnorm, ) # handle title. automatic title is: # f"{Baselined} {power} ({ch_name})" or # f"{Baselined} {power} ({combination} of {N} {ch_type}s)" if title == "auto": if combine_was_none: # one plot per channel which_chs = pick_names[ix] elif len(pick_names) == 1: # there was only one pick anyway which_chs = pick_names[0] else: # one plot for all chs combined which_chs = _set_title_multiple_electrodes( None, combine, pick_names, all_=True, ch_type=ch_type ) _prefix = "Power" if baseline is None else "Baselined power" _title = f"{_prefix} ({which_chs})" else: _title = title _fig.suptitle(_title) plt_show(show) return figs @verbose def plot_joint( self, *, timefreqs=None, picks=None, exclude=(), combine="mean", tmin=None, tmax=None, fmin=None, fmax=None, baseline=None, mode="mean", dB=False, yscale="auto", vmin=None, vmax=None, vlim=(None, None), cnorm=None, cmap=None, colorbar=True, title=None, # TODO consider deprecating this one, or adding an "auto" option show=True, topomap_args=None, image_args=None, verbose=None, ): """Plot TFRs as a two-dimensional image with topomap highlights. Parameters ---------- %(timefreqs)s %(picks_good_data)s %(exclude_psd)s Default is an empty :class:`tuple` which includes all channels. %(combine_tfr_plot_joint)s .. versionchanged:: 1.3 Added support for ``callable``. %(tmin_tmax_psd)s %(fmin_fmax_tfr)s %(baseline_rescale)s How baseline is computed is determined by the ``mode`` parameter. %(mode_tfr_plot)s %(dB_tfr_plot_topo)s %(yscale_tfr_plot)s %(vmin_vmax_tfr_plot)s %(vlim_tfr_plot_joint)s %(cnorm)s %(cmap_tfr_plot_topo)s %(colorbar_tfr_plot_joint)s %(title_none)s %(show)s %(topomap_args)s %(image_args)s %(verbose)s Returns ------- fig : matplotlib.figure.Figure The figure containing the topography. Notes ----- %(notes_timefreqs_tfr_plot_joint)s .. versionadded:: 0.16.0 """ from matplotlib import ticker from matplotlib.patches import ConnectionPatch # deprecations vlim = _warn_deprecated_vmin_vmax(vlim, vmin, vmax) # handle recursion picks = _picks_to_idx( self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False ) all_ch_types = np.array(self.get_channel_types()) uniq_ch_types = sorted(set(all_ch_types[picks])) if len(uniq_ch_types) > 1: msg = "Multiple channel types selected, returning one figure per type." logger.info(msg) figs = list() for this_type in uniq_ch_types: this_picks = np.intersect1d( picks, np.nonzero(np.isin(all_ch_types, this_type))[0], assume_unique=True, ) # TODO might be nice to not "copy first, then pick"; alternative might # be to subset the data with `this_picks` and then construct the "copy" # using __getstate__ and __setstate__ _tfr = self.copy().pick(this_picks) figs.append( _tfr.plot_joint( timefreqs=timefreqs, picks=None, baseline=baseline, mode=mode, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, vlim=vlim, cmap=cmap, dB=dB, colorbar=colorbar, show=False, title=title, yscale=yscale, combine=combine, exclude=(), topomap_args=topomap_args, verbose=verbose, ) ) return figs else: ch_type = uniq_ch_types[0] # handle defaults _validate_type(combine, ("str", "callable"), item_name="combine") # no `None` image_args = dict() if image_args is None else image_args topomap_args = dict() if topomap_args is None else topomap_args.copy() # make sure if topomap_args["ch_type"] is set, it matches what is in `self.info` topomap_args.setdefault("ch_type", ch_type) if topomap_args["ch_type"] != ch_type: raise ValueError( f"topomap_args['ch_type'] is {topomap_args['ch_type']} which does not " f"match the channel type present in the object ({ch_type})." ) # some necessary defaults topomap_args.setdefault("outlines", "head") topomap_args.setdefault("contours", 6) # don't pass these: topomap_args.pop("axes", None) topomap_args.pop("show", None) topomap_args.pop("colorbar", None) # get the time/freq limits of the image plot, to make sure requested annotation # times/freqs are in range _, times, freqs = self.get_data( picks=picks, exclude=(), tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, return_times=True, return_freqs=True, ) # validate requested annotation times and freqs timefreqs = _get_timefreqs(self, timefreqs) valid_timefreqs = dict() while timefreqs: (_time, _freq), (t_win, f_win) = timefreqs.popitem() # convert to half-windows t_win /= 2 f_win /= 2 # make sure the times / freqs are in-bounds msg = ( "Requested {} exceeds the range of the data ({}). Choose different " "`timefreqs`." ) if (times > _time).all() or (times < _time).all(): _var = f"time point ({_time:0.3f} s)" _range = f"{times[0]:0.3f} - {times[-1]:0.3f} s" raise ValueError(msg.format(_var, _range)) elif (freqs > _freq).all() or (freqs < _freq).all(): _var = f"frequency ({_freq:0.1f} Hz)" _range = f"{freqs[0]:0.1f} - {freqs[-1]:0.1f} Hz" raise ValueError(msg.format(_var, _range)) # snap the times/freqs to the nearest point we have an estimate for, and # store the validated points if t_win == 0: _time = times[np.argmin(np.abs(times - _time))] if f_win == 0: _freq = freqs[np.argmin(np.abs(freqs - _freq))] valid_timefreqs[(_time, _freq)] = (t_win, f_win) # prep data for topomaps (unlike image plot, must include all channels of the # current ch_type). Don't pass tmin/tmax here (crop later after baselining) topomap_picks = _picks_to_idx(self.info, ch_type) data, times, freqs = self.get_data( picks=topomap_picks, exclude=(), return_times=True, return_freqs=True ) # merge grads before baselining (makes ERDS visible) info = pick_info(self.info, sel=topomap_picks, copy=True) data, pos = _merge_if_grads( data=data, info=info, ch_type=ch_type, sphere=topomap_args.get("sphere"), combine=combine, ) # loop over intended topomap locations, to find one vlim that works for all. tf_array = np.array(list(valid_timefreqs)) # each row is [time, freq] tf_array = tf_array[tf_array[:, 0].argsort()] # sort by time _vmin, _vmax = (np.inf, -np.inf) topomap_arrays = list() topomap_titles = list() for _time, _freq in tf_array: # reduce data to the range of interest in the TF plane (i.e., finally crop) t_win, f_win = valid_timefreqs[(_time, _freq)] _tmin, _tmax = np.array([-1, 1]) * t_win + _time _fmin, _fmax = np.array([-1, 1]) * f_win + _freq _data, *_ = _prep_data_for_plot( data, times, freqs, tmin=_tmin, tmax=_tmax, fmin=_fmin, fmax=_fmax, baseline=baseline, mode=mode, verbose=verbose, ) _data = _data.mean(axis=(-1, -2)) # avg over times and freqs topomap_arrays.append(_data) _vmin = min(_data.min(), _vmin) _vmax = max(_data.max(), _vmax) # construct topopmap subplot title t_pm = "" if t_win == 0 else f" ± {t_win:0.2f}" f_pm = "" if f_win == 0 else f" ± {f_win:0.1f}" _title = f"{_time:0.2f}{t_pm} s,\n{_freq:0.1f}{f_pm} Hz" topomap_titles.append(_title) # handle cmap. Power may be negative depending on baseline strategy so set # `norm` empirically. vmin/vmax will be handled separately within the `plot()` # call for the image plot. norm = np.min(topomap_arrays) >= 0.0 cmap = _setup_cmap(cmap, norm=norm) topomap_args.setdefault("cmap", cmap[0]) # prevent interactive cbar # finalize topomap vlims and compute contour locations. # By passing `data=None` here ↓↓↓↓ we effectively assert vmin & vmax aren't None _vlim = _setup_vmin_vmax(data=None, vmin=_vmin, vmax=_vmax, norm=norm) topomap_args.setdefault("vlim", _vlim) locator, topomap_args["contours"] = _set_contour_locator( *topomap_args["vlim"], topomap_args["contours"] ) # initialize figure and do the image plot. `self.plot()` needed to wait to be # called until after `topomap_args` was fully populated --- we don't pass the # dict through to `self.plot()` explicitly here, but we do "reach back" and get # it if it's needed by the interactive rectangle selector. fig, image_ax, topomap_axes = _prepare_joint_axes(len(valid_timefreqs)) fig = self.plot( picks=picks, exclude=(), tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, baseline=baseline, mode=mode, dB=dB, combine=combine, yscale=yscale, vlim=vlim, cnorm=cnorm, cmap=cmap, colorbar=False, title=title, # mask, mask_style, mask_cmap, mask_alpha axes=image_ax, show=False, verbose=verbose, **image_args, )[0] # [0] because `.plot()` always returns a list # now, actually plot the topomaps for ax, title, _data in zip(topomap_axes, topomap_titles, topomap_arrays): ax.set_title(title) plot_topomap(_data, pos, axes=ax, show=False, **topomap_args) # draw colorbar if colorbar: cbar = fig.colorbar(ax.images[0]) cbar.locator = ticker.MaxNLocator(nbins=5) if locator is None else locator cbar.update_ticks() # draw the connection lines between time-frequency image and topoplots for (time_, freq_), topo_ax in zip(tf_array, topomap_axes): con = ConnectionPatch( xyA=[time_, freq_], xyB=[0.5, 0], coordsA="data", coordsB="axes fraction", axesA=image_ax, axesB=topo_ax, color="grey", linestyle="-", linewidth=1.5, alpha=0.66, zorder=1, clip_on=False, ) fig.add_artist(con) plt_show(show) return fig @verbose def plot_topo( self, picks=None, baseline=None, mode="mean", tmin=None, tmax=None, fmin=None, fmax=None, vmin=None, # TODO deprecate in favor of `vlim` (needs helper func refactor) vmax=None, layout=None, cmap="RdBu_r", title=None, # don't deprecate; topo titles aren't standard (color, size, just.) dB=False, colorbar=True, layout_scale=0.945, show=True, border="none", fig_facecolor="k", fig_background=None, font_color="w", yscale="auto", verbose=None, ): """Plot a TFR image for each channel in a sensor layout arrangement. Parameters ---------- %(picks_good_data)s %(baseline_rescale)s How baseline is computed is determined by the ``mode`` parameter. %(mode_tfr_plot)s %(tmin_tmax_psd)s %(fmin_fmax_tfr)s %(vmin_vmax_tfr_plot_topo)s %(layout_spectrum_plot_topo)s %(cmap_tfr_plot_topo)s %(title_none)s %(dB_tfr_plot_topo)s %(colorbar)s %(layout_scale)s %(show)s %(border_topo)s %(fig_facecolor)s %(fig_background)s %(font_color)s %(yscale_tfr_plot)s %(verbose)s Returns ------- fig : matplotlib.figure.Figure The figure containing the topography. """ # convenience vars times = self.times.copy() freqs = self.freqs data = self.data info = self.info info, data = _prepare_picks(info, data, picks, axis=0) del picks # TODO this is the only remaining call to _preproc_tfr; should be refactored # (to use _prep_data_for_plot?) data, times, freqs, vmin, vmax = _preproc_tfr( data, times, freqs, tmin, tmax, fmin, fmax, mode, baseline, vmin, vmax, dB, info["sfreq"], ) if layout is None: from mne import find_layout layout = find_layout(self.info) onselect_callback = partial(self._onselect, baseline=baseline, mode=mode) click_fun = partial( _imshow_tfr, tfr=data, freq=freqs, yscale=yscale, cmap=(cmap, True), onselect=onselect_callback, ) imshow = partial( _imshow_tfr_unified, tfr=data, freq=freqs, cmap=cmap, onselect=onselect_callback, ) fig = _plot_topo( info=info, times=times, show_func=imshow, click_func=click_fun, layout=layout, colorbar=colorbar, vmin=vmin, vmax=vmax, cmap=cmap, layout_scale=layout_scale, title=title, border=border, x_label="Time (s)", y_label="Frequency (Hz)", fig_facecolor=fig_facecolor, font_color=font_color, unified=True, img=True, ) add_background_image(fig, fig_background) plt_show(show) return fig @copy_function_doc_to_method_doc(plot_tfr_topomap) def plot_topomap( self, tmin=None, tmax=None, fmin=0.0, fmax=np.inf, *, ch_type=None, baseline=None, mode="mean", sensors=True, show_names=False, mask=None, mask_params=None, contours=6, outlines="head", sphere=None, image_interp=_INTERPOLATION_DEFAULT, extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, size=2, cmap=None, vlim=(None, None), cnorm=None, colorbar=True, cbar_fmt="%1.1e", units=None, axes=None, show=True, ): return plot_tfr_topomap( self, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, ch_type=ch_type, baseline=baseline, mode=mode, sensors=sensors, show_names=show_names, mask=mask, mask_params=mask_params, contours=contours, outlines=outlines, sphere=sphere, image_interp=image_interp, extrapolate=extrapolate, border=border, res=res, size=size, cmap=cmap, vlim=vlim, cnorm=cnorm, colorbar=colorbar, cbar_fmt=cbar_fmt, units=units, axes=axes, show=show, ) @verbose def save(self, fname, *, overwrite=False, verbose=None): """Save time-frequency data to disk (in HDF5 format). Parameters ---------- fname : path-like Path of file to save to, which should end with ``-tfr.h5`` or ``-tfr.hdf5``. %(overwrite)s %(verbose)s See Also -------- mne.time_frequency.read_tfrs """ _, write_hdf5 = _import_h5io_funcs() check_fname(fname, "time-frequency object", (".h5", ".hdf5")) fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) out = self.__getstate__() if "metadata" in out: out["metadata"] = _prepare_write_metadata(out["metadata"]) write_hdf5(fname, out, overwrite=overwrite, title="mnepython", slash="replace") @verbose def to_data_frame( self, picks=None, index=None, long_format=False, time_format=None, *, verbose=None, ): """Export data in tabular structure as a pandas DataFrame. Channels are converted to columns in the DataFrame. By default, additional columns ``'time'``, ``'freq'``, ``'epoch'``, and ``'condition'`` (epoch event description) are added, unless ``index`` is not ``None`` (in which case the columns specified in ``index`` will be used to form the DataFrame's index instead). ``'epoch'``, and ``'condition'`` are not supported for ``AverageTFR``. Parameters ---------- %(picks_all)s %(index_df_epo)s Valid string values are ``'time'``, ``'freq'``, ``'epoch'``, and ``'condition'`` for ``EpochsTFR`` and ``'time'`` and ``'freq'`` for ``AverageTFR``. Defaults to ``None``. %(long_format_df_epo)s %(time_format_df)s .. versionadded:: 0.23 %(verbose)s Returns ------- %(df_return)s """ # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa # arg checking valid_index_args = ["time", "freq"] if isinstance(self, EpochsTFR): valid_index_args.extend(["epoch", "condition"]) valid_time_formats = ["ms", "timedelta"] index = _check_pandas_index_arguments(index, valid_index_args) time_format = _check_time_format(time_format, valid_time_formats) # get data picks = _picks_to_idx(self.info, picks, "all", exclude=()) data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True) axis = self._dims.index("channel") if not isinstance(self, EpochsTFR): data = data[np.newaxis] # add singleton "epochs" axis axis += 1 n_epochs, n_picks, n_freqs, n_times = data.shape # reshape to (epochs*freqs*times) x signals data = np.moveaxis(data, axis, -1) data = data.reshape(n_epochs * n_freqs * n_times, n_picks) # prepare extra columns / multiindex mindex = list() times = _convert_times(times, time_format, self.info["meas_date"]) times = np.tile(times, n_epochs * n_freqs) freqs = np.tile(np.repeat(freqs, n_times), n_epochs) mindex.append(("time", times)) mindex.append(("freq", freqs)) if isinstance(self, EpochsTFR): mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs))) rev_event_id = {v: k for k, v in self.event_id.items()} conditions = [rev_event_id[k] for k in self.events[:, 2]] mindex.append(("condition", np.repeat(conditions, n_times * n_freqs))) assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:]) # build DataFrame if isinstance(self, EpochsTFR): default_index = ["condition", "epoch", "freq", "time"] else: default_index = ["freq", "time"] df = _build_data_frame( self, data, picks, long_format, mindex, index, default_index=default_index ) return df @fill_doc class AverageTFR(BaseTFR): """Data object for spectrotemporal representations of averaged data. .. warning:: The preferred means of creating AverageTFR objects is via the instance methods :meth:`mne.Epochs.compute_tfr` and :meth:`mne.Evoked.compute_tfr`, or via :meth:`mne.time_frequency.EpochsTFR.average`. Direct class instantiation is discouraged. Parameters ---------- %(info_not_none)s .. deprecated:: 1.7 Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead, or use :class:`~mne.time_frequency.AverageTFRArray` which retains the old API. data : ndarray, shape (n_channels, n_freqs, n_times) The data. .. deprecated:: 1.7 Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead, or use :class:`~mne.time_frequency.AverageTFRArray` which retains the old API. times : ndarray, shape (n_times,) The time values in seconds. .. deprecated:: 1.7 Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead and (optionally) use ``tmin`` and ``tmax`` to restrict the time domain; or use :class:`~mne.time_frequency.AverageTFRArray` which retains the old API. freqs : ndarray, shape (n_freqs,) The frequencies in Hz. nave : int The number of averaged TFRs. .. deprecated:: 1.7 Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead; ``nave`` will be inferred automatically. Or, use :class:`~mne.time_frequency.AverageTFRArray` which retains the old API. inst : instance of Evoked | instance of Epochs | dict The data from which to compute the time-frequency representation. Passing a :class:`dict` will create the AverageTFR using the ``__setstate__`` interface and is not recommended for typical use cases. %(method_tfr)s %(freqs_tfr)s %(tmin_tmax_psd)s %(picks_good_data_noref)s %(proj_psd)s %(decim_tfr)s %(comment_averagetfr)s %(n_jobs)s %(verbose)s %(method_kw_tfr)s Attributes ---------- %(baseline_tfr_attr)s %(ch_names_tfr_attr)s %(comment_averagetfr_attr)s %(freqs_tfr_attr)s %(info_not_none)s %(method_tfr_attr)s %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s See Also -------- RawTFR EpochsTFR AverageTFRArray mne.Evoked.compute_tfr mne.time_frequency.EpochsTFR.average Notes ----- The old API (prior to version 1.7) was:: AverageTFR(info, data, times, freqs, nave, comment=None, method=None) That API is still available via :class:`~mne.time_frequency.AverageTFRArray` for cases where the data are precomputed or do not originate from MNE-Python objects. The preferred new API uses instance methods:: evoked.compute_tfr(method, freqs, ...) epochs.compute_tfr(method, freqs, average=True, ...) The new API also supports AverageTFR instantiation from a :class:`dict`, but this is primarily for save/load and internal purposes, and wraps ``__setstate__``. During the transition from the old to the new API, it may be expedient to use :class:`~mne.time_frequency.AverageTFRArray` as a "quick-fix" approach to updating scripts under active development. References ---------- .. footbibliography:: """ def __init__( self, info=None, data=None, times=None, freqs=None, nave=None, *, inst=None, method=None, tmin=None, tmax=None, picks=None, proj=False, decim=1, comment=None, n_jobs=None, verbose=None, **method_kw, ): from ..epochs import BaseEpochs from ..evoked import Evoked from ._stockwell import _check_input_st, _compute_freqs_st # dict is allowed for __setstate__ compatibility, and Epochs.compute_tfr() can # return an AverageTFR depending on its parameters, so Epochs input is allowed _validate_type( inst, (BaseEpochs, Evoked, dict), "object passed to AverageTFR constructor" ) # stockwell API is very different from multitaper/morlet if method == "stockwell" and not isinstance(inst, dict): if isinstance(freqs, str) and freqs == "auto": fmin, fmax = None, None elif len(freqs) == 2: fmin, fmax = freqs else: raise ValueError( "for Stockwell method, freqs must be a length-2 iterable " f'or "auto", got {freqs}.' ) method_kw.update(fmin=fmin, fmax=fmax) # Compute freqs. We need a couple lines of code dupe here (also in # BaseTFR.__init__) to get the subset of times to pass to _check_input_st() _mask = _time_mask(inst.times, tmin, tmax, sfreq=inst.info["sfreq"]) _times = inst.times[_mask].copy() _, default_nfft, _ = _check_input_st(_times, None) n_fft = method_kw.get("n_fft", default_nfft) *_, freqs = _compute_freqs_st(fmin, fmax, n_fft, inst.info["sfreq"]) # use Evoked.comment or str(Epochs.event_id) as the default comment... if comment is None: comment = getattr(inst, "comment", ",".join(getattr(inst, "event_id", ""))) # ...but don't overwrite if it's coming in with a comment already set if isinstance(inst, dict): inst.setdefault("comment", comment) else: self._comment = getattr(self, "_comment", comment) super().__init__( inst, method, freqs, tmin=tmin, tmax=tmax, picks=picks, proj=proj, decim=decim, n_jobs=n_jobs, verbose=verbose, **method_kw, ) def __getstate__(self): """Prepare AverageTFR object for serialization.""" out = super().__getstate__() out.update(nave=self.nave, comment=self.comment) # NOTE: self._itc should never exist in the instance returned to the user; it # is temporarily present in the output from the tfr_array_* function, and is # split out into a separate AverageTFR object (and deleted from the object # holding power estimates) before those objects are passed back to the user. # The following lines are there because we make use of __getstate__ to achieve # that splitting of objects. if hasattr(self, "_itc"): out.update(itc=self._itc) return out def __setstate__(self, state): """Unpack AverageTFR from serialized format.""" super().__setstate__(state) self._comment = state.get("comment", "") self._nave = state.get("nave", 1) @property def comment(self): return self._comment @comment.setter def comment(self, comment): self._comment = comment @property def nave(self): return self._nave @nave.setter def nave(self, nave): self._nave = nave def _get_instance_data(self, time_mask): # AverageTFRs can be constructed from Epochs data, so we triage shape here. # Evoked data get a fake singleton "epoch" axis prepended dim = slice(None) if _get_instance_type_string(self) == "Epochs" else np.newaxis data = self.inst.get_data(picks=self._picks)[dim, :, time_mask] self._nave = getattr(self.inst, "nave", data.shape[0]) return data @fill_doc class AverageTFRArray(AverageTFR): """Data object for *precomputed* spectrotemporal representations of averaged data. Parameters ---------- %(info_not_none)s %(data_tfr)s %(times)s %(freqs_tfr_array)s nave : int The number of averaged TFRs. %(comment_averagetfr_attr)s %(method_tfr_array)s Attributes ---------- %(baseline_tfr_attr)s %(ch_names_tfr_attr)s %(comment_averagetfr_attr)s %(freqs_tfr_attr)s %(info_not_none)s %(method_tfr_attr)s %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s See Also -------- AverageTFR EpochsTFRArray mne.Epochs.compute_tfr mne.Evoked.compute_tfr """ def __init__( self, info, data, times, freqs, *, nave=None, comment=None, method=None ): state = dict(info=info, data=data, times=times, freqs=freqs) for name, optional in dict(nave=nave, comment=comment, method=method).items(): if optional is not None: state[name] = optional self.__setstate__(state) @fill_doc class EpochsTFR(BaseTFR, GetEpochsMixin): """Data object for spectrotemporal representations of epoched data. .. important:: The preferred means of creating EpochsTFR objects from :class:`~mne.Epochs` objects is via the instance method :meth:`~mne.Epochs.compute_tfr`. To create an EpochsTFR object from pre-computed data (i.e., a NumPy array) use :class:`~mne.time_frequency.EpochsTFRArray`. Parameters ---------- %(info_not_none)s .. deprecated:: 1.7 Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. data : ndarray, shape (n_channels, n_freqs, n_times) The data. .. deprecated:: 1.7 Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. times : ndarray, shape (n_times,) The time values in seconds. .. deprecated:: 1.7 Pass an instance of :class:`~mne.Epochs` as ``inst`` instead and (optionally) use ``tmin`` and ``tmax`` to restrict the time domain; or use :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. %(freqs_tfr_epochs)s inst : instance of Epochs The data from which to compute the time-frequency representation. %(method_tfr_epochs)s %(comment_tfr_attr)s .. deprecated:: 1.7 Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. %(tmin_tmax_psd)s %(picks_good_data_noref)s %(proj_psd)s %(decim_tfr)s %(events_epochstfr)s .. deprecated:: 1.7 Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. %(event_id_epochstfr)s .. deprecated:: 1.7 Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. selection : array List of indices of selected events (not dropped or ignored etc.). For example, if the original event array had 4 events and the second event has been dropped, this attribute would be np.array([0, 2, 3]). .. deprecated:: 1.7 Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. drop_log : tuple of tuple A tuple of the same length as the event array used to initialize the ``EpochsTFR`` object. If the i-th original event is still part of the selection, drop_log[i] will be an empty tuple; otherwise it will be a tuple of the reasons the event is not longer in the selection, e.g.: - ``'IGNORED'`` If it isn't part of the current subset defined by the user - ``'NO_DATA'`` or ``'TOO_SHORT'`` If epoch didn't contain enough data names of channels that exceeded the amplitude threshold - ``'EQUALIZED_COUNTS'`` See :meth:`~mne.Epochs.equalize_event_counts` - ``'USER'`` For user-defined reasons (see :meth:`~mne.Epochs.drop`). .. deprecated:: 1.7 Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. %(metadata_epochstfr)s .. deprecated:: 1.7 Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. %(n_jobs)s %(verbose)s %(method_kw_tfr)s Attributes ---------- %(baseline_tfr_attr)s %(ch_names_tfr_attr)s %(comment_tfr_attr)s %(drop_log)s %(event_id_attr)s %(events_attr)s %(freqs_tfr_attr)s %(info_not_none)s %(metadata_attr)s %(method_tfr_attr)s %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s See Also -------- mne.Epochs.compute_tfr RawTFR AverageTFR EpochsTFRArray References ---------- .. footbibliography:: """ def __init__( self, info=None, data=None, times=None, freqs=None, *, inst=None, method=None, comment=None, tmin=None, tmax=None, picks=None, proj=False, decim=1, events=None, event_id=None, selection=None, drop_log=None, metadata=None, n_jobs=None, verbose=None, **method_kw, ): from ..epochs import BaseEpochs # deprecations. TODO remove after 1.7 release depr_params = dict(info=info, data=data, times=times, comment=comment) bad_params = list() for name, param in depr_params.items(): if param is not None: bad_params.append(name) if len(bad_params): _s = _pl(bad_params) is_are = _pl(bad_params, "is", "are") bad_params_list = '", "'.join(bad_params) warn( f'Parameter{_s} "{bad_params_list}" {is_are} deprecated and will be ' "removed in version 1.8. For a quick fix, use ``EpochsTFRArray`` with " "the same parameters. For a long-term fix, see the docstring notes.", FutureWarning, ) if inst is not None: raise ValueError( "Do not pass `inst` alongside deprecated params " f'"{bad_params_list}"; see docstring of AverageTFR for guidance.' ) # sensible defaults are created in __setstate__ so only pass these through # if they're user-specified optional = dict( freqs=freqs, method=method, events=events, event_id=event_id, selection=selection, drop_log=drop_log, metadata=metadata, ) optional_params = { key: val for key, val in optional.items() if val is not None } inst = depr_params | optional_params # end TODO ↑↑↑↑↑↑ # dict is allowed for __setstate__ compatibility _validate_type( inst, (BaseEpochs, dict), "object passed to EpochsTFR constructor", "Epochs" ) super().__init__( inst, method, freqs, tmin=tmin, tmax=tmax, picks=picks, proj=proj, decim=decim, n_jobs=n_jobs, verbose=verbose, **method_kw, ) @fill_doc def __getitem__(self, item): """Subselect epochs from an EpochsTFR. Parameters ---------- %(item)s Access options are the same as for :class:`~mne.Epochs` objects, see the docstring Notes section of :meth:`mne.Epochs.__getitem__` for explanation. Returns ------- %(getitem_epochstfr_return)s """ return super().__getitem__(item) def __getstate__(self): """Prepare EpochsTFR object for serialization.""" out = super().__getstate__() out.update( metadata=self._metadata, drop_log=self.drop_log, event_id=self.event_id, events=self.events, selection=self.selection, raw_times=self._raw_times, ) return out def __setstate__(self, state): """Unpack EpochsTFR from serialized format.""" if state["data"].ndim != 4: raise ValueError(f"EpochsTFR data should be 4D, got {state['data'].ndim}.") super().__setstate__(state) self._metadata = state.get("metadata", None) n_epochs = self.shape[0] n_times = self.shape[-1] fake_samps = np.linspace( n_times, n_times * (n_epochs + 1), n_epochs, dtype=int, endpoint=False ) fake_events = np.dstack( (fake_samps, np.zeros_like(fake_samps), np.ones_like(fake_samps)) ).squeeze(axis=0) self.events = state.get("events", _ensure_events(fake_events)) self.event_id = state.get("event_id", _check_event_id(None, self.events)) self.drop_log = state.get("drop_log", tuple()) self.selection = state.get("selection", np.arange(n_epochs)) self._bad_dropped = True # always true, need for `equalize_event_counts()` def __next__(self, return_event_id=False): """Iterate over EpochsTFR objects. NOTE: __iter__() and _stop_iter() are defined by the GetEpochs mixin. Parameters ---------- return_event_id : bool If ``True``, return both the EpochsTFR data and its associated ``event_id``. Returns ------- epoch : array of shape (n_channels, n_freqs, n_times) The single-epoch time-frequency data. event_id : int The integer event id associated with the epoch. Only returned if ``return_event_id`` is ``True``. """ if self._current >= len(self._data): self._stop_iter() epoch = self._data[self._current] event_id = self.events[self._current][-1] self._current += 1 if return_event_id: return epoch, event_id return epoch def _check_singleton(self): """Check if self contains only one Epoch, and return it as an AverageTFR.""" if self.shape[0] > 1: calling_func = inspect.currentframe().f_back.f_code.co_name raise NotImplementedError( f"Cannot call {calling_func}() from EpochsTFR with multiple epochs; " "please subselect a single epoch before plotting." ) return list(self.iter_evoked())[0] def _get_instance_data(self, time_mask): return self.inst.get_data(picks=self._picks)[:, :, time_mask] def _update_epoch_attributes(self): # adjust dims and shape if self.method != "stockwell": # stockwell consumes epochs dimension self._dims = ("epoch",) + self._dims self._shape = (len(self.inst),) + self._shape # we need these for to_data_frame() self.event_id = self.inst.event_id.copy() self.events = self.inst.events.copy() self.selection = self.inst.selection.copy() # we need these for __getitem__() self.drop_log = deepcopy(self.inst.drop_log) self._metadata = self.inst.metadata # we need this for compatibility with equalize_event_counts() self._bad_dropped = True def average(self, method="mean", *, dim="epochs", copy=False): """Aggregate the EpochsTFR across epochs, frequencies, or times. Parameters ---------- method : "mean" | "median" | callable How to aggregate the data across the given ``dim``. If callable, must take a :class:`NumPy array` of shape ``(n_epochs, n_channels, n_freqs, n_times)`` and return an array with one fewer dimensions (which dimension is collapsed depends on the value of ``dim``). Default is ``"mean"``. dim : "epochs" | "freqs" | "times" The dimension along which to combine the data. copy : bool Whether to return a copy of the modified instance, or modify in place. Ignored when ``dim="epochs"`` or ``"times"`` because those options return different types (:class:`~mne.time_frequency.AverageTFR` and :class:`~mne.time_frequency.EpochsSpectrum`, respectively). Returns ------- tfr : instance of EpochsTFR | AverageTFR | EpochsSpectrum The aggregated TFR object. Notes ----- Passing in ``np.median`` is considered unsafe for complex data; pass the string ``"median"`` instead to compute the *marginal* median (i.e. the median of the real and imaginary components separately). See discussion here: https://github.com/scipy/scipy/pull/12676#issuecomment-783370228 """ _check_option("dim", dim, ("epochs", "freqs", "times")) axis = self._dims.index(dim[:-1]) # self._dims entries aren't plural func = _check_combine(mode=method, axis=axis) data = func(self.data) n_epochs, n_channels, n_freqs, n_times = self.data.shape freqs, times = self.freqs, self.times if dim == "epochs": expected_shape = self._data.shape[1:] elif dim == "freqs": expected_shape = (n_epochs, n_channels, n_times) freqs = np.mean(self.freqs, keepdims=True) elif dim == "times": expected_shape = (n_epochs, n_channels, n_freqs) times = np.mean(self.times, keepdims=True) if data.shape != expected_shape: raise RuntimeError( "EpochsTFR.average() got a method that resulted in data of shape " f"{data.shape}, but it should be {expected_shape}." ) state = self.__getstate__() # restore singleton freqs axis (not necessary for epochs/times: class changes) if dim == "freqs": data = np.expand_dims(data, axis=axis) else: state["dims"] = (*state["dims"][:axis], *state["dims"][axis + 1 :]) state["data"] = data state["info"] = deepcopy(self.info) state["freqs"] = freqs state["times"] = times if dim == "epochs": state["inst_type_str"] = "Evoked" state["nave"] = n_epochs state["comment"] = f"{method} of {n_epochs} EpochsTFR{_pl(n_epochs)}" out = AverageTFR(inst=state) out._data_type = "Average Power" return out elif dim == "times": return EpochsSpectrum( state, method=None, fmin=None, fmax=None, tmin=None, tmax=None, picks=None, exclude=None, proj=None, remove_dc=None, n_jobs=None, ) # ↓↓↓ these two are for dim == "freqs" elif copy: return EpochsTFR(inst=state, method=None, freqs=None) else: self._data = np.expand_dims(data, axis=axis) self._freqs = freqs return self @verbose def drop(self, indices, reason="USER", verbose=None): """Drop epochs based on indices or boolean mask. .. note:: The indices refer to the current set of undropped epochs rather than the complete set of dropped and undropped epochs. They are therefore not necessarily consistent with any external indices (e.g., behavioral logs). To drop epochs based on external criteria, do not use the ``preload=True`` flag when constructing an Epochs object, and call this method before calling the :meth:`mne.Epochs.drop_bad` or :meth:`mne.Epochs.load_data` methods. Parameters ---------- indices : array of int or bool Set epochs to remove by specifying indices to remove or a boolean mask to apply (where True values get removed). Events are correspondingly modified. reason : str Reason for dropping the epochs ('ECG', 'timeout', 'blink' etc). Default: 'USER'. %(verbose)s Returns ------- epochs : instance of Epochs or EpochsTFR The epochs with indices dropped. Operates in-place. """ from ..epochs import BaseEpochs BaseEpochs.drop(self, indices=indices, reason=reason, verbose=verbose) return self def iter_evoked(self, copy=False): """Iterate over EpochsTFR to yield a sequence of AverageTFR objects. The AverageTFR objects will each contain a single epoch (i.e., no averaging is performed). This method resets the EpochTFR instance's iteration state to the first epoch. Parameters ---------- copy : bool Whether to yield copies of the data and measurement info, or views/pointers. """ self.__iter__() state = self.__getstate__() state["inst_type_str"] = "Evoked" state["dims"] = state["dims"][1:] # drop "epochs" while True: try: data, event_id = self.__next__(return_event_id=True) except StopIteration: break if copy: state["info"] = deepcopy(self.info) state["data"] = data.copy() else: state["data"] = data state["nave"] = 1 yield AverageTFR(inst=state, method=None, freqs=None, comment=str(event_id)) @verbose @copy_doc(BaseTFR.plot) def plot( self, picks=None, *, exclude=(), tmin=None, tmax=None, fmin=None, fmax=None, baseline=None, mode="mean", dB=False, combine=None, layout=None, # TODO deprecate; not used in orig implementation yscale="auto", vmin=None, vmax=None, vlim=(None, None), cnorm=None, cmap=None, colorbar=True, title=None, # don't deprecate this one; has (useful) option title="auto" mask=None, mask_style=None, mask_cmap="Greys", mask_alpha=0.1, axes=None, show=True, verbose=None, ): singleton_epoch = self._check_singleton() return singleton_epoch.plot( picks=picks, exclude=exclude, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, baseline=baseline, mode=mode, dB=dB, combine=combine, layout=layout, yscale=yscale, vmin=vmin, vmax=vmax, vlim=vlim, cnorm=cnorm, cmap=cmap, colorbar=colorbar, title=title, mask=mask, mask_style=mask_style, mask_cmap=mask_cmap, mask_alpha=mask_alpha, axes=axes, show=show, verbose=verbose, ) @verbose @copy_doc(BaseTFR.plot_topo) def plot_topo( self, picks=None, baseline=None, mode="mean", tmin=None, tmax=None, fmin=None, fmax=None, vmin=None, # TODO deprecate in favor of `vlim` (needs helper func refactor) vmax=None, layout=None, cmap=None, title=None, # don't deprecate; topo titles aren't standard (color, size, just.) dB=False, colorbar=True, layout_scale=0.945, show=True, border="none", fig_facecolor="k", fig_background=None, font_color="w", yscale="auto", verbose=None, ): singleton_epoch = self._check_singleton() return singleton_epoch.plot_topo( picks=picks, baseline=baseline, mode=mode, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, vmin=vmin, vmax=vmax, layout=layout, cmap=cmap, title=title, dB=dB, colorbar=colorbar, layout_scale=layout_scale, show=show, border=border, fig_facecolor=fig_facecolor, fig_background=fig_background, font_color=font_color, yscale=yscale, verbose=verbose, ) @verbose @copy_doc(BaseTFR.plot_joint) def plot_joint( self, *, timefreqs=None, picks=None, exclude=(), combine="mean", tmin=None, tmax=None, fmin=None, fmax=None, baseline=None, mode="mean", dB=False, yscale="auto", vmin=None, vmax=None, vlim=(None, None), cnorm=None, cmap=None, colorbar=True, title=None, show=True, topomap_args=None, image_args=None, verbose=None, ): singleton_epoch = self._check_singleton() return singleton_epoch.plot_joint( timefreqs=timefreqs, picks=picks, exclude=exclude, combine=combine, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, baseline=baseline, mode=mode, dB=dB, yscale=yscale, vmin=vmin, vmax=vmax, vlim=vlim, cnorm=cnorm, cmap=cmap, colorbar=colorbar, title=title, show=show, topomap_args=topomap_args, image_args=image_args, verbose=verbose, ) @copy_doc(BaseTFR.plot_topomap) def plot_topomap( self, tmin=None, tmax=None, fmin=0.0, fmax=np.inf, *, ch_type=None, baseline=None, mode="mean", sensors=True, show_names=False, mask=None, mask_params=None, contours=6, outlines="head", sphere=None, image_interp=_INTERPOLATION_DEFAULT, extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, size=2, cmap=None, vlim=(None, None), cnorm=None, colorbar=True, cbar_fmt="%1.1e", units=None, axes=None, show=True, ): singleton_epoch = self._check_singleton() return singleton_epoch.plot_topomap( tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, ch_type=ch_type, baseline=baseline, mode=mode, sensors=sensors, show_names=show_names, mask=mask, mask_params=mask_params, contours=contours, outlines=outlines, sphere=sphere, image_interp=image_interp, extrapolate=extrapolate, border=border, res=res, size=size, cmap=cmap, vlim=vlim, cnorm=cnorm, colorbar=colorbar, cbar_fmt=cbar_fmt, units=units, axes=axes, show=show, ) @fill_doc class EpochsTFRArray(EpochsTFR): """Data object for *precomputed* spectrotemporal representations of epoched data. Parameters ---------- %(info_not_none)s %(data_tfr)s %(times)s %(freqs_tfr_array)s %(comment_tfr_attr)s %(method_tfr_array)s %(events_epochstfr)s %(event_id_epochstfr)s %(selection)s %(drop_log)s %(metadata_epochstfr)s Attributes ---------- %(baseline_tfr_attr)s %(ch_names_tfr_attr)s %(comment_tfr_attr)s %(drop_log)s %(event_id_attr)s %(events_attr)s %(freqs_tfr_attr)s %(info_not_none)s %(metadata_attr)s %(method_tfr_attr)s %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s See Also -------- AverageTFR mne.Epochs.compute_tfr mne.Evoked.compute_tfr """ def __init__( self, info, data, times, freqs, *, comment=None, method=None, events=None, event_id=None, selection=None, drop_log=None, metadata=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) optional = dict( comment=comment, method=method, events=events, event_id=event_id, selection=selection, drop_log=drop_log, metadata=metadata, ) for name, value in optional.items(): if value is not None: state[name] = value self.__setstate__(state) @fill_doc class RawTFR(BaseTFR): """Data object for spectrotemporal representations of continuous data. .. warning:: The preferred means of creating RawTFR objects from :class:`~mne.io.Raw` objects is via the instance method :meth:`~mne.io.Raw.compute_tfr`. Direct class instantiation is not supported. Parameters ---------- inst : instance of Raw The data from which to compute the time-frequency representation. %(method_tfr)s %(freqs_tfr)s %(tmin_tmax_psd)s %(picks_good_data_noref)s %(proj_psd)s %(reject_by_annotation_tfr)s %(decim_tfr)s %(n_jobs)s %(verbose)s %(method_kw_tfr)s Attributes ---------- ch_names : list The channel names. freqs : array Frequencies at which the amplitude, power, or fourier coefficients have been computed. %(info_not_none)s method : str The method used to compute the spectra (``'morlet'``, ``'multitaper'`` or ``'stockwell'``). See Also -------- mne.io.Raw.compute_tfr EpochsTFR AverageTFR References ---------- .. footbibliography:: """ def __init__( self, inst, method=None, freqs=None, *, tmin=None, tmax=None, picks=None, proj=False, reject_by_annotation=False, decim=1, n_jobs=None, verbose=None, **method_kw, ): from ..io import BaseRaw # dict is allowed for __setstate__ compatibility _validate_type( inst, (BaseRaw, dict), "object passed to RawTFR constructor", "Raw" ) super().__init__( inst, method, freqs, tmin=tmin, tmax=tmax, picks=picks, proj=proj, reject_by_annotation=reject_by_annotation, decim=decim, n_jobs=n_jobs, verbose=verbose, **method_kw, ) def __getitem__(self, item): """Get RawTFR data. Parameters ---------- item : int | slice | array-like Indexing is similar to a :class:`NumPy array`; see Notes. Returns ------- %(getitem_tfr_return)s Notes ----- The last axis is always time, the next-to-last axis is always frequency, and the first axis is always channel. If ``method='multitaper'`` and ``output='complex'`` then the second axis will be taper index. Integer-, list-, and slice-based indexing is possible: - ``raw_tfr[[0, 2]]`` gives the whole time-frequency plane for the first and third channels. - ``raw_tfr[..., :3, :]`` gives the first 3 frequency bins and all times for all channels (and tapers, if present). - ``raw_tfr[..., :100]`` gives the first 100 time samples in all frequency bins for all channels (and tapers). - ``raw_tfr[(4, 7)]`` is the same as ``raw_tfr[4, 7]``. .. note:: Unlike :class:`~mne.io.Raw` objects (which returns a tuple of the requested data values and the corresponding times), accessing :class:`~mne.time_frequency.RawTFR` values via subscript does **not** return the corresponding frequency bin values. If you need them, use ``RawTFR.freqs[freq_indices]`` or ``RawTFR.get_data(..., return_freqs=True)``. """ from ..io import BaseRaw self._parse_get_set_params = partial(BaseRaw._parse_get_set_params, self) return BaseRaw._getitem(self, item, return_times=False) def _get_instance_data(self, time_mask, reject_by_annotation): start, stop = np.where(time_mask)[0][[0, -1]] rba = "NaN" if reject_by_annotation else None data = self.inst.get_data( self._picks, start, stop + 1, reject_by_annotation=rba ) # prepend a singleton "epochs" axis return data[np.newaxis] @fill_doc class RawTFRArray(RawTFR): """Data object for *precomputed* spectrotemporal representations of continuous data. Parameters ---------- %(info_not_none)s %(data_tfr)s %(times)s %(freqs_tfr_array)s %(method_tfr_array)s Attributes ---------- %(baseline_tfr_attr)s %(ch_names_tfr_attr)s %(freqs_tfr_attr)s %(info_not_none)s %(method_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s See Also -------- RawTFR mne.io.Raw.compute_tfr EpochsTFRArray AverageTFRArray """ def __init__( self, info, data, times, freqs, *, method=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) if method is not None: state["method"] = method self.__setstate__(state) def combine_tfr(all_tfr, weights="nave"): """Merge AverageTFR data by weighted addition. Create a new AverageTFR instance, using a combination of the supplied instances as its data. By default, the mean (weighted by trials) is used. Subtraction can be performed by passing negative weights (e.g., [1, -1]). Data must have the same channels and the same time instants. Parameters ---------- all_tfr : list of AverageTFR The tfr datasets. weights : list of float | str The weights to apply to the data of each AverageTFR instance. Can also be ``'nave'`` to weight according to tfr.nave, or ``'equal'`` to use equal weighting (each weighted as ``1/N``). Returns ------- tfr : AverageTFR The new TFR data. Notes ----- .. versionadded:: 0.11.0 """ tfr = all_tfr[0].copy() if isinstance(weights, str): if weights not in ("nave", "equal"): raise ValueError('Weights must be a list of float, or "nave" or "equal"') if weights == "nave": weights = np.array([e.nave for e in all_tfr], float) weights /= weights.sum() else: # == 'equal' weights = [1.0 / len(all_tfr)] * len(all_tfr) weights = np.array(weights, float) if weights.ndim != 1 or weights.size != len(all_tfr): raise ValueError("Weights must be the same size as all_tfr") ch_names = tfr.ch_names for t_ in all_tfr[1:]: assert t_.ch_names == ch_names, ValueError( f"{tfr} and {t_} do not contain the same channels" ) assert np.max(np.abs(t_.times - tfr.times)) < 1e-7, ValueError( f"{tfr} and {t_} do not contain the same time instants" ) # use union of bad channels bads = list(set(tfr.info["bads"]).union(*(t_.info["bads"] for t_ in all_tfr[1:]))) tfr.info["bads"] = bads # XXX : should be refactored with combined_evoked function tfr.data = sum(w * t_.data for w, t_ in zip(weights, all_tfr)) tfr.nave = max(int(1.0 / sum(w**2 / e.nave for w, e in zip(weights, all_tfr))), 1) return tfr # Utils # ↓↓↓↓↓↓↓↓↓↓↓ this is still used in _stockwell.py def _get_data(inst, return_itc): """Get data from Epochs or Evoked instance as epochs x ch x time.""" from ..epochs import BaseEpochs from ..evoked import Evoked if not isinstance(inst, (BaseEpochs, Evoked)): raise TypeError("inst must be Epochs or Evoked") if isinstance(inst, BaseEpochs): data = inst.get_data(copy=False) else: if return_itc: raise ValueError("return_itc must be False for evoked data") data = inst.data[np.newaxis].copy() return data def _prepare_picks(info, data, picks, axis): """Prepare the picks.""" picks = _picks_to_idx(info, picks, exclude="bads") info = pick_info(info, picks) sl = [slice(None)] * data.ndim sl[axis] = picks data = data[tuple(sl)] return info, data def _centered(arr, newsize): """Aux Function to center data.""" # Return the center newsize portion of the array. newsize = np.asarray(newsize) currsize = np.array(arr.shape) startind = (currsize - newsize) // 2 endind = startind + newsize myslice = [slice(startind[k], endind[k]) for k in range(len(endind))] return arr[tuple(myslice)] def _preproc_tfr( data, times, freqs, tmin, tmax, fmin, fmax, mode, baseline, vmin, vmax, dB, sfreq, copy=None, ): """Aux Function to prepare tfr computation.""" if copy is None: copy = baseline is not None data = rescale(data, times, baseline, mode, copy=copy) if np.iscomplexobj(data): # complex amplitude → real power (for plotting); if data are # real-valued they should already be power data = (data * data.conj()).real # crop time itmin, itmax = None, None idx = np.where(_time_mask(times, tmin, tmax, sfreq=sfreq))[0] if tmin is not None: itmin = idx[0] if tmax is not None: itmax = idx[-1] + 1 times = times[itmin:itmax] # crop freqs ifmin, ifmax = None, None idx = np.where(_time_mask(freqs, fmin, fmax, sfreq=sfreq))[0] if fmin is not None: ifmin = idx[0] if fmax is not None: ifmax = idx[-1] + 1 freqs = freqs[ifmin:ifmax] # crop data data = data[:, ifmin:ifmax, itmin:itmax] if dB: data = 10 * np.log10(data) vmin, vmax = _setup_vmin_vmax(data, vmin, vmax) return data, times, freqs, vmin, vmax def _ensure_slice(decim): """Aux function checking the decim parameter.""" _validate_type(decim, ("int-like", slice), "decim") if not isinstance(decim, slice): decim = slice(None, None, int(decim)) # ensure that we can actually use `decim.step` if decim.step is None: decim = slice(decim.start, decim.stop, 1) return decim # i/o @verbose def write_tfrs(fname, tfr, overwrite=False, *, verbose=None): """Write a TFR dataset to hdf5. Parameters ---------- fname : path-like The file name, which should end with ``-tfr.h5``. tfr : RawTFR | EpochsTFR | AverageTFR | list of RawTFR | list of EpochsTFR | list of AverageTFR The (list of) TFR object(s) to save in one file. If ``tfr.comment`` is ``None``, a sequential numeric string name will be generated on the fly, based on the order in which the TFR objects are passed. This can be used to selectively load single TFR objects from the file later. %(overwrite)s %(verbose)s See Also -------- read_tfrs Notes ----- .. versionadded:: 0.9.0 """ # noqa E501 _, write_hdf5 = _import_h5io_funcs() out = [] if not isinstance(tfr, (list, tuple)): tfr = [tfr] for ii, tfr_ in enumerate(tfr): comment = ii if getattr(tfr_, "comment", None) is None else tfr_.comment state = tfr_.__getstate__() if "metadata" in state: state["metadata"] = _prepare_write_metadata(state["metadata"]) out.append((comment, state)) write_hdf5(fname, out, overwrite=overwrite, title="mnepython", slash="replace") @verbose def read_tfrs(fname, condition=None, *, verbose=None): """Load a TFR object from disk. Parameters ---------- fname : path-like Path to a TFR file in HDF5 format, which should end with ``-tfr.h5`` or ``-tfr.hdf5``. condition : int or str | list of int or str | None The condition to load. If ``None``, all conditions will be returned. Defaults to ``None``. %(verbose)s Returns ------- tfr : RawTFR | EpochsTFR | AverageTFR | list of RawTFR | list of EpochsTFR | list of AverageTFR The loaded time-frequency object. See Also -------- mne.time_frequency.RawTFR.save mne.time_frequency.EpochsTFR.save mne.time_frequency.AverageTFR.save write_tfrs Notes ----- .. versionadded:: 0.9.0 """ # noqa E501 read_hdf5, _ = _import_h5io_funcs() fname = _check_fname(fname=fname, overwrite="read", must_exist=False) valid_fnames = tuple( f"{sep}tfr.{ext}" for sep in ("-", "_") for ext in ("h5", "hdf5") ) check_fname(fname, "tfr", valid_fnames) logger.info(f"Reading {fname} ...") hdf5_dict = read_hdf5(fname, title="mnepython", slash="replace") # single TFR from TFR.save() if "inst_type_str" in hdf5_dict: if "epoch" in hdf5_dict["dims"]: Klass = EpochsTFR elif "nave" in hdf5_dict: Klass = AverageTFR else: Klass = RawTFR out = Klass(inst=hdf5_dict) if getattr(out, "metadata", None) is not None: out.metadata = _prepare_read_metadata(out.metadata) return out # maybe multiple TFRs from write_tfrs() return _read_multiple_tfrs(hdf5_dict, condition=condition, verbose=verbose) @verbose def _read_multiple_tfrs(tfr_data, condition=None, *, verbose=None): """Read (possibly multiple) TFR datasets from an h5 file written by write_tfrs().""" out = list() keys = list() # tfr_data is a list of (comment, tfr_dict) tuples for key, tfr in tfr_data: keys.append(str(key)) # auto-assigned keys are ints is_epochs = tfr["data"].ndim == 4 is_average = "nave" in tfr if condition is not None: if not is_average: raise NotImplementedError( "condition is only supported when reading AverageTFRs." ) if key != condition: continue tfr = dict(tfr) tfr["info"] = Info(tfr["info"]) tfr["info"]._check_consistency() if "metadata" in tfr: tfr["metadata"] = _prepare_read_metadata(tfr["metadata"]) # additional keys needed for TFR __setstate__ defaults = dict(baseline=None, data_type="Power Estimates") if is_epochs: Klass = EpochsTFR defaults.update( inst_type_str="Epochs", dims=("epoch", "channel", "freq", "time") ) elif is_average: Klass = AverageTFR defaults.update(inst_type_str="Evoked", dims=("channel", "freq", "time")) else: Klass = RawTFR defaults.update(inst_type_str="Raw", dims=("channel", "freq", "time")) out.append(Klass(inst=defaults | tfr)) if len(out) == 0: raise ValueError( f'Cannot find condition "{condition}" in this file. ' f'The file contains conditions {", ".join(keys)}' ) if len(out) == 1: out = out[0] return out def _get_timefreqs(tfr, timefreqs): """Find and/or setup timefreqs for `tfr.plot_joint`.""" # Input check timefreq_error_msg = ( "Supplied `timefreqs` are somehow malformed. Please supply None, " "a list of tuple pairs, or a dict of such tuple pairs, not {}" ) if isinstance(timefreqs, dict): for k, v in timefreqs.items(): for item in (k, v): if len(item) != 2 or any(not _is_numeric(n) for n in item): raise ValueError(timefreq_error_msg, item) elif timefreqs is not None: if not hasattr(timefreqs, "__len__"): raise ValueError(timefreq_error_msg.format(timefreqs)) if len(timefreqs) == 2 and all(_is_numeric(v) for v in timefreqs): timefreqs = [tuple(timefreqs)] # stick a pair of numbers in a list else: for item in timefreqs: if ( hasattr(item, "__len__") and len(item) == 2 and all(_is_numeric(n) for n in item) ): pass else: raise ValueError(timefreq_error_msg.format(item)) # If None, automatic identification of max peak else: order = max((1, tfr.data.shape[2] // 30)) peaks_idx = argrelmax(tfr.data, order=order, axis=2) if peaks_idx[0].size == 0: _, p_t, p_f = np.unravel_index(tfr.data.argmax(), tfr.data.shape) timefreqs = [(tfr.times[p_t], tfr.freqs[p_f])] else: peaks = [tfr.data[0, f, t] for f, t in zip(peaks_idx[1], peaks_idx[2])] peakmax_idx = np.argmax(peaks) peakmax_time = tfr.times[peaks_idx[2][peakmax_idx]] peakmax_freq = tfr.freqs[peaks_idx[1][peakmax_idx]] timefreqs = [(peakmax_time, peakmax_freq)] timefreqs = { tuple(k): np.asarray(timefreqs[k]) if isinstance(timefreqs, dict) else np.array([0, 0]) for k in timefreqs } return timefreqs def _check_tfr_complex(tfr, reason="source space estimation"): """Check that time-frequency epochs or average data is complex.""" if not np.iscomplexobj(tfr.data): raise RuntimeError(f"Time-frequency data must be complex for {reason}") def _merge_if_grads(data, info, ch_type, sphere, combine=None): if ch_type == "grad": grad_picks = _pair_grad_sensors(info, topomap_coords=False) pos = _find_topomap_coords(info, picks=grad_picks[::2], sphere=sphere) grad_method = combine if isinstance(combine, str) else "rms" data, _ = _merge_ch_data(data[grad_picks], ch_type, [], method=grad_method) else: pos, _ = _get_pos_outlines(info, picks=ch_type, sphere=sphere) return data, pos @verbose def _prep_data_for_plot( data, times, freqs, *, tmin=None, tmax=None, fmin=None, fmax=None, baseline=None, mode=None, dB=False, verbose=None, ): # baseline copy = baseline is not None data = rescale(data, times, baseline, mode, copy=copy, verbose=verbose) # crop times time_mask = np.nonzero(_time_mask(times, tmin, tmax))[0] times = times[time_mask] # crop freqs freq_mask = np.nonzero(_time_mask(freqs, fmin, fmax))[0] freqs = freqs[freq_mask] # crop data data = data[..., freq_mask, :][..., time_mask] # complex amplitude → real power; real-valued data is already power (or ITC) if np.iscomplexobj(data): data = (data * data.conj()).real if dB: data = 10 * np.log10(data) return data, times, freqs def _warn_deprecated_vmin_vmax(vlim, vmin, vmax): if vmin is not None or vmax is not None: warning = "Parameters `vmin` and `vmax` are deprecated, use `vlim` instead." if vlim[0] is None and vlim[1] is None: vlim = (vmin, vmax) else: warning += ( " You've also provided a (non-default) value for `vlim`, " "so `vmin` and `vmax` will be ignored." ) warn(warning, FutureWarning) return vlim