Files
Feature-Extraction/dist/client/mne/time_frequency/tfr.py

4342 lines
141 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""A module which implements the time-frequency estimation.
Morlet code inspired by Matlab code from Sheraz Khan & Brainstorm & SPM
"""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import inspect
from copy import deepcopy
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
from scipy.fft import fft, ifft
from scipy.signal import argrelmax
from .._fiff.meas_info import ContainsMixin, Info
from .._fiff.pick import _picks_to_idx, pick_info
from ..baseline import _check_baseline, rescale
from ..channels.channels import UpdateChannelsMixin
from ..channels.layout import _find_topomap_coords, _merge_ch_data, _pair_grad_sensors
from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
from ..filter import next_fast_len
from ..parallel import parallel_func
from ..utils import (
ExtendedTimeMixin,
GetEpochsMixin,
SizeMixin,
_build_data_frame,
_check_combine,
_check_event_id,
_check_fname,
_check_method_kwargs,
_check_option,
_check_pandas_index_arguments,
_check_pandas_installed,
_check_time_format,
_convert_times,
_ensure_events,
_freq_mask,
_import_h5io_funcs,
_is_numeric,
_pl,
_prepare_read_metadata,
_prepare_write_metadata,
_time_mask,
_validate_type,
check_fname,
copy_doc,
copy_function_doc_to_method_doc,
fill_doc,
legacy,
logger,
object_diff,
repr_html,
sizeof_fmt,
verbose,
warn,
)
from ..utils.spectrum import _get_instance_type_string
from ..viz.topo import _imshow_tfr, _imshow_tfr_unified, _plot_topo
from ..viz.topomap import (
_add_colorbar,
_get_pos_outlines,
_set_contour_locator,
plot_tfr_topomap,
plot_topomap,
)
from ..viz.utils import (
_make_combine_callable,
_prepare_joint_axes,
_set_title_multiple_electrodes,
_setup_cmap,
_setup_vmin_vmax,
add_background_image,
figure_nobar,
plt_show,
)
from .multitaper import dpss_windows, tfr_array_multitaper
from .spectrum import EpochsSpectrum
@fill_doc
def morlet(sfreq, freqs, n_cycles=7.0, sigma=None, zero_mean=False):
    """Compute Morlet wavelets for the given frequency range.

    Parameters
    ----------
    sfreq : float
        The sampling frequency.
    freqs : float | array-like, shape (n_freqs,)
        Frequencies to compute Morlet wavelets for.
    n_cycles : float | array-like, shape (n_freqs,)
        Number of cycles. Can be a fixed number (float) or one per frequency
        (array-like).
    sigma : float, default None
        It controls the width of the wavelet, i.e., its temporal
        resolution. If sigma is None the temporal resolution
        is adapted with the frequency like for all wavelet transforms.
        The higher the frequency the shorter is the wavelet.
        If sigma is fixed the temporal resolution is fixed
        like for the short time Fourier transform and the number
        of oscillations increases with the frequency.
    zero_mean : bool, default False
        Make sure the wavelet has a mean of zero.

    Returns
    -------
    Ws : list of ndarray | ndarray
        The wavelets time series. If ``freqs`` was a float, a single
        ndarray is returned instead of a list of ndarray.

    Raises
    ------
    ValueError
        If any frequency is not strictly positive, if ``freqs`` has more
        than one dimension, or if ``n_cycles`` has neither one value nor one
        value per frequency.

    See Also
    --------
    mne.time_frequency.fwhm

    Notes
    -----
    %(morlet_reference)s
    %(fwhm_morlet_notes)s

    References
    ----------
    .. footbibliography::

    Examples
    --------
    Let's show a simple example of the relationship between ``n_cycles`` and
    the FWHM using :func:`mne.time_frequency.fwhm`:

    .. plot::

        import numpy as np
        import matplotlib.pyplot as plt
        from mne.time_frequency import morlet, fwhm

        sfreq, freq, n_cycles = 1000., 10, 7  # i.e., 700 ms
        this_fwhm = fwhm(freq, n_cycles)
        wavelet = morlet(sfreq=sfreq, freqs=freq, n_cycles=n_cycles)
        M, w = len(wavelet), n_cycles  # convert to SciPy convention
        s = w * sfreq / (2 * freq * np.pi)  # from SciPy docs

        _, ax = plt.subplots(layout="constrained")
        colors = dict(real="#66CCEE", imag="#EE6677")
        t = np.arange(-M // 2 + 1, M // 2 + 1) / sfreq
        for kind in ('real', 'imag'):
            ax.plot(
                t, getattr(wavelet, kind), label=kind, color=colors[kind],
            )
        ax.plot(t, np.abs(wavelet), label='abs', color='k', lw=1., zorder=6)
        half_max = np.max(np.abs(wavelet)) / 2.
        ax.plot([-this_fwhm / 2., this_fwhm / 2.], [half_max, half_max],
                color='k', linestyle='-', label='FWHM', zorder=6)
        ax.legend(loc='upper right')
        ax.set(xlabel='Time (s)', ylabel='Amplitude')
    """  # noqa: E501
    Ws = list()
    n_cycles = np.array(n_cycles, float).ravel()
    freqs = np.array(freqs, float)
    if np.any(freqs <= 0):
        raise ValueError("all frequencies in 'freqs' must be greater than 0.")
    # Validate the dimensionality first: the size comparison below must also
    # work for a scalar (0-d) ``freqs``, where len() would raise TypeError.
    _check_option("freqs.ndim", freqs.ndim, [0, 1])
    if (n_cycles.size != 1) and (n_cycles.size != freqs.size):
        raise ValueError("n_cycles should be fixed or defined for each frequency.")
    singleton = freqs.ndim == 0
    if singleton:
        freqs = freqs[np.newaxis]
    # one n_cycles value per frequency (a single fixed value is repeated)
    n_cycles = np.broadcast_to(n_cycles, freqs.shape)
    for f, this_n_cycles in zip(freqs, n_cycles):
        # sigma_t is the stddev of the Gaussian window in the time domain; it
        # is scale-dependent unless ``sigma`` fixes it across frequencies
        if sigma is None:
            sigma_t = this_n_cycles / (2.0 * np.pi * f)
        else:
            sigma_t = this_n_cycles / (2.0 * np.pi * sigma)
        # time vector. We go 5 standard deviations out to make sure we're
        # *very* close to zero at the ends. We also make sure that there's a
        # sample at exactly t=0
        t = np.arange(0.0, 5.0 * sigma_t, 1.0 / sfreq)
        t = np.r_[-t[::-1], t[1:]]
        oscillation = np.exp(2.0 * 1j * np.pi * f * t)
        if zero_mean:
            # this offset is equivalent to the kappa_sigma term in the Morlet
            # wavelet definition and satisfies the "admissibility criterion"
            # for CWTs
            real_offset = np.exp(-2 * (np.pi * f * sigma_t) ** 2)
            oscillation -= real_offset
        gaussian_envelope = np.exp(-(t**2) / (2.0 * sigma_t**2))
        W = oscillation * gaussian_envelope
        # the scaling factor here is proportional to what is used in
        # Tallon-Baudry 1997: (sigma_t*sqrt(pi))^(-1/2). It yields a wavelet
        # with norm sqrt(2) for the full wavelet / norm 1 for the real part
        W /= np.sqrt(0.5) * np.linalg.norm(W.ravel())
        Ws.append(W)
    if singleton:
        Ws = Ws[0]
    return Ws
def fwhm(freq, n_cycles):
    """Return the full-width half maximum (FWHM) of a Morlet wavelet.

    The value follows the formula given by :footcite:t:`Cohen2019`.

    Parameters
    ----------
    freq : float
        The oscillation frequency of the wavelet.
    n_cycles : float
        The duration of the wavelet, expressed as the number of oscillation
        cycles.

    Returns
    -------
    fwhm : float
        The full-width half maximum of the wavelet.

    Notes
    -----
    .. versionadded:: 1.3

    References
    ----------
    .. footbibliography::
    """
    # FWHM = n_cycles * sqrt(2 * ln(2)) / (pi * freq)
    gaussian_factor = np.sqrt(2 * np.log(2))
    return gaussian_factor * n_cycles / (np.pi * freq)
def _make_dpss(
sfreq,
freqs,
n_cycles=7.0,
time_bandwidth=4.0,
zero_mean=False,
return_weights=False,
):
"""Compute DPSS tapers for the given frequency range.
Parameters
----------
sfreq : float
The sampling frequency.
freqs : ndarray, shape (n_freqs,)
The frequencies in Hz.
n_cycles : float | ndarray, shape (n_freqs,), default 7.
The number of cycles globally or for each frequency.
time_bandwidth : float, default 4.0
Time x Bandwidth product.
The number of good tapers (low-bias) is chosen automatically based on
this to equal floor(time_bandwidth - 1).
Default is 4.0, giving 3 good tapers.
zero_mean : bool | None, , default False
Make sure the wavelet has a mean of zero.
return_weights : bool
Whether to return the concentration weights.
Returns
-------
Ws : list of array
The wavelets time series.
"""
Ws = list()
freqs = np.array(freqs)
if np.any(freqs <= 0):
raise ValueError("all frequencies in 'freqs' must be greater than 0.")
if time_bandwidth < 2.0:
raise ValueError("time_bandwidth should be >= 2.0 for good tapers")
n_taps = int(np.floor(time_bandwidth - 1))
n_cycles = np.atleast_1d(n_cycles)
if n_cycles.size != 1 and n_cycles.size != len(freqs):
raise ValueError("n_cycles should be fixed or defined for each frequency.")
for m in range(n_taps):
Wm = list()
for k, f in enumerate(freqs):
if len(n_cycles) != 1:
this_n_cycles = n_cycles[k]
else:
this_n_cycles = n_cycles[0]
t_win = this_n_cycles / float(f)
t = np.arange(0.0, t_win, 1.0 / sfreq)
# Making sure wavelets are centered before tapering
oscillation = np.exp(2.0 * 1j * np.pi * f * (t - t_win / 2.0))
# Get dpss tapers
tapers, conc = dpss_windows(
t.shape[0], time_bandwidth / 2.0, n_taps, sym=False
)
Wk = oscillation * tapers[m]
if zero_mean: # to make it zero mean
real_offset = Wk.mean()
Wk -= real_offset
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
Wm.append(Wk)
Ws.append(Wm)
if return_weights:
return Ws, conc
return Ws
# Low level convolution
def _get_nfft(wavelets, X, use_fft=True, check=True):
n_times = X.shape[-1]
max_size = max(w.size for w in wavelets)
if max_size > n_times:
msg = (
f"At least one of the wavelets ({max_size}) is longer than the "
f"signal ({n_times}). Consider using a longer signal or "
"shorter wavelets."
)
if check:
if use_fft:
warn(msg, UserWarning)
else:
raise ValueError(msg)
nfft = n_times + max_size - 1
nfft = next_fast_len(nfft) # 2 ** int(np.ceil(np.log2(nfft)))
return nfft
def _cwt_gen(X, Ws, *, fsize=0, mode="same", decim=1, use_fft=True):
    """Compute cwt with fft based convolutions or temporal convolutions.

    This is a generator: it yields one time-frequency array per row (signal)
    of ``X``.

    Parameters
    ----------
    X : array of shape (n_signals, n_times)
        The data.
    Ws : list of array
        Wavelets time series.
    fsize : int
        FFT length. Only used when ``use_fft=True``.
    mode : {'full', 'valid', 'same'}
        See numpy.convolve.
    decim : int | slice, default 1
        To reduce memory usage, decimation factor after time-frequency
        decomposition.
        If `int`, returns tfr[..., ::decim].
        If `slice`, returns tfr[..., decim].

        .. note:: Decimation may create aliasing artifacts.
    use_fft : bool, default True
        Use the FFT for convolutions or not.

    Yields
    ------
    tfr : array, shape (n_freqs, n_time_decim)
        The complex time-frequency transform of one signal. The same array
        object is overwritten in place for every signal, so callers must copy
        or consume it before advancing the generator.
    """
    _check_option("mode", mode, ["same", "valid", "full"])
    decim = _ensure_slice(decim)
    X = np.asarray(X)

    # n_times is the input length; n_times_out the length after decimation
    _, n_times = X.shape
    n_times_out = X[:, decim].shape[1]
    n_freqs = len(Ws)

    # precompute FFTs of Ws once; they are reused for every signal
    if use_fft:
        fft_Ws = np.empty((n_freqs, fsize), dtype=np.complex128)
        for i, W in enumerate(Ws):
            fft_Ws[i] = fft(W, fsize)

    # Make generator looping across signals; ``tfr`` is a single buffer that
    # is refilled (and re-yielded) for each signal
    tfr = np.zeros((n_freqs, n_times_out), dtype=np.complex128)
    for x in X:
        if use_fft:
            fft_x = fft(x, fsize)

        # Loop across wavelets
        for ii, W in enumerate(Ws):
            if use_fft:
                # keep only the linear-convolution part of the padded product
                ret = ifft(fft_x * fft_Ws[ii])[: n_times + W.size - 1]
            else:
                # Work around multarray.correlate->OpenBLAS bug on ppc64le
                # ret = np.correlate(x, W, mode=mode)
                # (real and imaginary parts are convolved separately instead)
                ret = np.convolve(x, W.real, mode=mode) + 1j * np.convolve(
                    x, W.imag, mode=mode
                )

            # Center and decimate decomposition
            if mode == "valid":
                sz = int(abs(W.size - n_times)) + 1
                offset = (n_times - sz) // 2
                # NOTE(review): assumes decim.step is an int; a slice whose
                # step is None would fail on ``offset // decim.step`` —
                # confirm _ensure_slice always provides a step.
                this_slice = slice(offset // decim.step, (offset + sz) // decim.step)
                if use_fft:
                    ret = _centered(ret, sz)
                tfr[ii, this_slice] = ret[decim]
            elif mode == "full" and not use_fft:
                # trim the full convolution back to n_times samples
                start = (W.size - 1) // 2
                end = len(ret) - (W.size // 2)
                ret = ret[start:end]
                tfr[ii, :] = ret[decim]
            else:
                # "same" mode (and "full" when use_fft=True) keeps the
                # centered n_times samples
                if use_fft:
                    ret = _centered(ret, n_times)
                tfr[ii, :] = ret[decim]
        yield tfr
# Loop of convolution: single trial
def _compute_tfr(
    epoch_data,
    freqs,
    sfreq=1.0,
    method="morlet",
    n_cycles=7.0,
    zero_mean=None,
    time_bandwidth=None,
    use_fft=True,
    decim=1,
    output="complex",
    n_jobs=None,
    *,
    verbose=None,
):
    """Compute time-frequency transforms.

    Parameters
    ----------
    epoch_data : array of shape (n_epochs, n_channels, n_times)
        The epochs.
    freqs : array-like of floats, shape (n_freqs)
        The frequencies.
    sfreq : float | int, default 1.0
        Sampling frequency of the data.
    method : 'multitaper' | 'morlet', default 'morlet'
        The time-frequency method. 'morlet' convolves a Morlet wavelet.
        'multitaper' uses complex exponentials windowed with multiple DPSS
        tapers.
    n_cycles : float | array of float, default 7.0
        Number of cycles in the wavelet. Fixed number
        or one per frequency.
    zero_mean : bool | None, default None
        None means True for method='multitaper' and False for method='morlet'.
        If True, make sure the wavelets have a mean of zero.
    time_bandwidth : float, default None
        If None and method=multitaper, will be set to 4.0 (3 tapers).
        Time x (Full) Bandwidth product. Only applies if
        method == 'multitaper'. The number of good tapers (low-bias) is
        chosen automatically based on this to equal floor(time_bandwidth - 1).
    use_fft : bool, default True
        Use the FFT for convolutions or not.
    decim : int | slice, default 1
        To reduce memory usage, decimation factor after time-frequency
        decomposition.
        If `int`, returns tfr[..., ::decim].
        If `slice`, returns tfr[..., decim].

        .. note::
            Decimation may create aliasing artifacts, yet decimation
            is done after the convolutions.
    output : str, default 'complex'
        * 'complex' : single trial complex.
        * 'power' : single trial power.
        * 'phase' : single trial phase.
        * 'avg_power' : average of single trial power.
        * 'itc' : inter-trial coherence.
        * 'avg_power_itc' : average of single trial power and inter-trial
          coherence across trials.
    %(n_jobs)s
        The parallelization is implemented across channels.
    %(verbose)s

    Returns
    -------
    out : array
        Time frequency transform of epoch_data. If output is in ['complex',
        'phase', 'power'], then shape of ``out`` is ``(n_epochs, n_chans,
        n_freqs, n_times)``, else it is ``(n_chans, n_freqs, n_times)``.
        However, using multitaper method and output ``'complex'`` or
        ``'phase'`` results in shape of ``out`` being ``(n_epochs, n_chans,
        n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the
        real values in the ``output`` contain average power and the imaginary
        values contain the ITC: ``out = avg_power + i * itc``.

    Raises
    ------
    ValueError
        If ``epoch_data`` is not 3-D, if a frequency exceeds Nyquist, or if
        any wavelet is longer than the signal.
    """
    # Check data
    epoch_data = np.asarray(epoch_data)
    if epoch_data.ndim != 3:
        raise ValueError(
            "epoch_data must be of shape (n_epochs, n_chans, "
            f"n_times), got {epoch_data.shape}"
        )

    # Check params
    freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim = _check_tfr_param(
        freqs,
        sfreq,
        method,
        zero_mean,
        n_cycles,
        time_bandwidth,
        use_fft,
        decim,
        output,
    )
    decim = _ensure_slice(decim)
    if (freqs > sfreq / 2.0).any():
        raise ValueError(
            "Cannot compute freq above Nyquist freq of the data "
            f"({sfreq / 2.0:0.1f} Hz), got {freqs.max():0.1f} Hz"
        )

    # We decimate *after* decomposition, so we need to create our kernels
    # for the original sfreq
    if method == "morlet":
        W = morlet(sfreq, freqs, n_cycles=n_cycles, zero_mean=zero_mean)
        Ws = [W]  # to have same dimensionality as the 'multitaper' case
    elif method == "multitaper":
        Ws = _make_dpss(
            sfreq,
            freqs,
            n_cycles=n_cycles,
            time_bandwidth=time_bandwidth,
            zero_mean=zero_mean,
        )

    # Check wavelets: *every* kernel must fit in the signal. Inspecting only
    # Ws[0][0] misses longer kernels when freqs are not ascending or when
    # n_cycles varies across frequencies.
    all_Ws = [w for taper_Ws in Ws for w in taper_Ws]
    if max((len(w) for w in all_Ws), default=0) > epoch_data.shape[2]:
        raise ValueError(
            "At least one of the wavelets is longer than the "
            "signal. Use a longer signal or shorter wavelets."
        )

    # Initialize output
    n_freqs = len(freqs)
    n_tapers = len(Ws)
    n_epochs, n_chans, n_times = epoch_data[:, :, decim].shape
    if output in ("power", "phase", "avg_power", "itc"):
        dtype = np.float64
    elif output in ("complex", "avg_power_itc"):
        # avg_power_itc is stored as power + 1i * itc to keep a
        # simple dimensionality
        dtype = np.complex128

    if ("avg_" in output) or ("itc" in output):
        out = np.empty((n_chans, n_freqs, n_times), dtype)
    elif output in ["complex", "phase"] and method == "multitaper":
        out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype)
    else:
        out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)

    # Parallel computation
    parallel, my_cwt, n_jobs = parallel_func(_time_frequency_loop, n_jobs)

    # Parallelization is applied across channels.
    tfrs = parallel(
        my_cwt(channel, Ws, output, use_fft, "same", decim, method)
        for channel in epoch_data.transpose(1, 0, 2)
    )

    # FIXME: to avoid overheads we should use np.array_split()
    for channel_idx, tfr in enumerate(tfrs):
        out[channel_idx] = tfr

    if ("avg_" not in output) and ("itc" not in output):
        # This is to enforce that the first dimension is for epochs
        if output in ["complex", "phase"] and method == "multitaper":
            out = out.transpose(2, 0, 1, 3, 4)
        else:
            out = out.transpose(1, 0, 2, 3)
    return out
def _check_tfr_param(
freqs, sfreq, method, zero_mean, n_cycles, time_bandwidth, use_fft, decim, output
):
"""Aux. function to _compute_tfr to check the params validity."""
# Check freqs
if not isinstance(freqs, (list, np.ndarray)):
raise ValueError(f"freqs must be an array-like, got {type(freqs)} instead.")
freqs = np.asarray(freqs, dtype=float)
if freqs.ndim != 1:
raise ValueError(
f"freqs must be of shape (n_freqs,), got {np.array(freqs.shape)} "
"instead."
)
# Check sfreq
if not isinstance(sfreq, (float, int)):
raise ValueError(f"sfreq must be a float or an int, got {type(sfreq)} instead.")
sfreq = float(sfreq)
# Default zero_mean = True if multitaper else False
zero_mean = method == "multitaper" if zero_mean is None else zero_mean
if not isinstance(zero_mean, bool):
raise ValueError(
f"zero_mean should be of type bool, got {type(zero_mean)}. instead"
)
freqs = np.asarray(freqs)
# Check n_cycles
if isinstance(n_cycles, (int, float)):
n_cycles = float(n_cycles)
elif isinstance(n_cycles, (list, np.ndarray)):
n_cycles = np.array(n_cycles)
if len(n_cycles) != len(freqs):
raise ValueError(
"n_cycles must be a float or an array of length "
f"{len(freqs)} frequencies, got {len(n_cycles)} cycles instead."
)
else:
raise ValueError(
f"n_cycles must be a float or an array, got {type(n_cycles)} instead."
)
# Check time_bandwidth
if (method == "morlet") and (time_bandwidth is not None):
raise ValueError('time_bandwidth only applies to "multitaper" method.')
elif method == "multitaper":
time_bandwidth = 4.0 if time_bandwidth is None else float(time_bandwidth)
# Check use_fft
if not isinstance(use_fft, bool):
raise ValueError(f"use_fft must be a boolean, got {type(use_fft)} instead.")
# Check decim
if isinstance(decim, int):
decim = slice(None, None, decim)
if not isinstance(decim, slice):
raise ValueError(
f"decim must be an integer or a slice, got {type(decim)} instead."
)
# Check output
_check_option(
"output",
output,
["complex", "power", "phase", "avg_power_itc", "avg_power", "itc"],
)
_check_option("method", method, ["multitaper", "morlet"])
return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim
def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
    """Aux. function to _compute_tfr.

    Loops time-frequency transform across wavelets and epochs.

    Parameters
    ----------
    X : array, shape (n_epochs, n_times)
        The epochs data of a single channel.
    Ws : list, shape (n_tapers, n_wavelets, n_times)
        The wavelets.
    output : str
        * 'complex' : single trial complex.
        * 'power' : single trial power.
        * 'phase' : single trial phase.
        * 'avg_power' : average of single trial power.
        * 'itc' : inter-trial coherence.
        * 'avg_power_itc' : average of single trial power and inter-trial
          coherence across trials.
    use_fft : bool
        Use the FFT for convolutions or not.
    mode : {'full', 'valid', 'same'}
        See numpy.convolve.
    decim : slice
        The decimation slice: e.g. power[:, decim]
    method : str | None
        Used only for multitapering to create tapers dimension in the output
        if ``output in ['complex', 'phase']``.

    Returns
    -------
    tfrs : ndarray
        Shape ``(n_freqs, n_times)`` for averaged outputs (``'avg_*'`` and
        ``'itc'``), ``(n_tapers, n_epochs, n_freqs, n_times)`` for multitaper
        ``'complex'``/``'phase'``, and ``(n_epochs, n_freqs, n_times)``
        otherwise; ``n_times`` is the decimated length. Outputs other than
        ``'complex'``/``'phase'`` are averaged over tapers.
    """
    # Set output type
    dtype = np.float64
    if output in ["complex", "avg_power_itc"]:
        dtype = np.complex128

    # Init outputs
    decim = _ensure_slice(decim)
    n_tapers = len(Ws)
    n_epochs, n_times = X[:, decim].shape
    n_freqs = len(Ws[0])
    if ("avg_" in output) or ("itc" in output):
        tfrs = np.zeros((n_freqs, n_times), dtype=dtype)
    elif output in ["complex", "phase"] and method == "multitaper":
        tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), dtype=dtype)
    else:
        tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype)

    # Loops across tapers.
    for taper_idx, W in enumerate(Ws):
        # No need to check here, it's done earlier (outside parallel part)
        nfft = _get_nfft(W, X, use_fft, check=False)
        coefs = _cwt_gen(X, W, fsize=nfft, mode=mode, decim=decim, use_fft=use_fft)

        # Inter-trial phase locking is apparently computed per taper...
        # (plf is reset for every taper and its magnitude is added below)
        if "itc" in output:
            plf = np.zeros((n_freqs, n_times), dtype=np.complex128)

        # Loop across epochs
        for epoch_idx, tfr in enumerate(coefs):
            # Transform complex values
            if output in ["power", "avg_power"]:
                tfr = (tfr * tfr.conj()).real  # power
            elif output == "phase":
                tfr = np.angle(tfr)
            elif output == "avg_power_itc":
                tfr_abs = np.abs(tfr)
                # unit phasors; a coefficient that is exactly 0 yields NaN
                plf += tfr / tfr_abs  # phase
                tfr = tfr_abs**2  # power
            elif output == "itc":
                plf += tfr / np.abs(tfr)  # phase
                continue  # not need to stack anything else than plf

            # Stack or add
            if ("avg_" in output) or ("itc" in output):
                tfrs += tfr
            elif output in ["complex", "phase"] and method == "multitaper":
                tfrs[taper_idx, epoch_idx] += tfr
            else:
                tfrs[epoch_idx] += tfr

        # Compute inter trial coherence (stored in the imaginary part for
        # 'avg_power_itc')
        if output == "avg_power_itc":
            tfrs += 1j * np.abs(plf)
        elif output == "itc":
            tfrs += np.abs(plf)

    # Normalization of average metrics
    if ("avg_" in output) or ("itc" in output):
        tfrs /= n_epochs

    # Normalization by number of taper (complex/phase keep one entry per taper)
    if n_tapers > 1 and output not in ["complex", "phase"]:
        tfrs /= n_tapers
    return tfrs
@fill_doc
def cwt(X, Ws, use_fft=True, mode="same", decim=1):
    """Compute time-frequency decomposition with continuous wavelet transform.

    Parameters
    ----------
    X : array-like, shape (n_signals, n_times)
        The signals. Array-likes are converted with :func:`numpy.asarray`.
    Ws : list of array
        Wavelets time series.
    use_fft : bool
        Use FFT for convolutions. Defaults to True.
    mode : 'same' | 'valid' | 'full'
        Convention for convolution. 'full' is currently not implemented with
        ``use_fft=False``. Defaults to ``'same'``.
    %(decim_tfr)s

    Returns
    -------
    tfr : array, shape (n_signals, n_freqs, n_times)
        The time-frequency decompositions (``n_times`` after decimation).

    See Also
    --------
    mne.time_frequency.tfr_morlet : Compute time-frequency decomposition
        with Morlet wavelets.
    """
    # _get_nfft reads X.shape and _cwt_array indexes X[:, decim], so convert
    # array-likes (e.g. nested lists) before either is reached.
    X = np.asarray(X)
    nfft = _get_nfft(Ws, X, use_fft)
    return _cwt_array(X, Ws, nfft, mode, decim, use_fft)
def _cwt_array(X, Ws, nfft, mode, decim, use_fft):
    """Collect the per-signal output of ``_cwt_gen`` into one complex array.

    Returns an array of shape ``(n_signals, len(Ws), n_times_decimated)``.
    """
    decim = _ensure_slice(decim)
    n_signals, n_times_out = X[:, decim].shape
    out = np.empty((n_signals, len(Ws), n_times_out), dtype=np.complex128)
    per_signal = _cwt_gen(X, Ws, fsize=nfft, mode=mode, decim=decim, use_fft=use_fft)
    # the generator reuses its buffer, so each result is copied into ``out``
    for signal_idx, coefs in enumerate(per_signal):
        out[signal_idx] = coefs
    return out
def _tfr_aux(
    method, inst, freqs, decim, return_itc, picks, average, output, **tfr_params
):
    """Forward a legacy ``tfr_*`` call to ``inst.compute_tfr``.

    Shared by ``tfr_morlet`` and ``tfr_multitaper``. ``average`` and
    ``return_itc`` are only forwarded for Epochs input; for any other input
    (logged as Evoked) ``average`` is forced to False.

    Raises
    ------
    ValueError
        If ``output="complex"`` is combined with averaging, or if ITC is
        requested without averaging.
    """
    from ..epochs import BaseEpochs

    compute_kwargs = dict(
        method=method,
        freqs=freqs,
        picks=picks,
        decim=decim,
        output=output,
        **tfr_params,
    )
    is_epochs = isinstance(inst, BaseEpochs)
    if is_epochs:
        compute_kwargs.update(average=average, return_itc=return_itc)
    elif average:
        logger.info("inst is Evoked, setting `average=False`")
        average = False

    if average and output == "complex":
        raise ValueError('output must be "power" if average=True')
    if not average and return_itc:
        raise ValueError("Inter-trial coherence is not supported with average=False")
    return inst.compute_tfr(**compute_kwargs)
@legacy(alt='.compute_tfr(method="morlet")')
@verbose
def tfr_morlet(
    inst,
    freqs,
    n_cycles,
    use_fft=False,
    return_itc=True,
    decim=1,
    n_jobs=None,
    picks=None,
    zero_mean=True,
    average=True,
    output="power",
    verbose=None,
):
    """Compute Time-Frequency Representation (TFR) using Morlet wavelets.

    This performs the same computation as
    `~mne.time_frequency.tfr_array_morlet`, but takes `~mne.Epochs` or
    `~mne.Evoked` objects as input instead of
    :class:`NumPy arrays <numpy.ndarray>`.

    Parameters
    ----------
    inst : Epochs | Evoked
        The epochs or evoked object to decompose.
    %(freqs_tfr_array)s
    %(n_cycles_tfr)s
    use_fft : bool, default False
        Whether to use FFT-based convolution.
    return_itc : bool, default True
        Whether to return the inter-trial coherence (ITC) in addition to the
        averaged power. Must be ``False`` for evoked data.
    %(decim_tfr)s
    %(n_jobs)s
    picks : array-like of int | None, default None
        Indices of the channels to decompose. If None, all available
        good data channels are decomposed.
    zero_mean : bool, default True
        Whether to force the wavelet to have a mean of zero.

        .. versionadded:: 0.13.0
    %(average_tfr)s
    output : str
        Either ``"power"`` (default) or ``"complex"``. ``"complex"`` requires
        ``average=False``.

        .. versionadded:: 0.15.0
    %(verbose)s

    Returns
    -------
    power : AverageTFR | EpochsTFR
        The averaged or single-trial power.
    itc : AverageTFR | EpochsTFR
        The inter-trial coherence (ITC). Only returned if ``return_itc``
        is True.

    See Also
    --------
    mne.time_frequency.tfr_array_morlet
    mne.time_frequency.tfr_multitaper
    mne.time_frequency.tfr_array_multitaper
    mne.time_frequency.tfr_stockwell
    mne.time_frequency.tfr_array_stockwell

    Notes
    -----
    %(morlet_reference)s
    %(temporal_window_tfr_intro)s
    %(temporal_window_tfr_morlet_notes)s

    See :func:`mne.time_frequency.morlet` for more information about the
    Morlet wavelet.

    References
    ----------
    .. footbibliography::
    """
    return _tfr_aux(
        "morlet",
        inst,
        freqs,
        decim,
        return_itc,
        picks,
        average,
        n_cycles=n_cycles,
        n_jobs=n_jobs,
        use_fft=use_fft,
        zero_mean=zero_mean,
        output=output,
    )
@verbose
def tfr_array_morlet(
    data,
    sfreq,
    freqs,
    n_cycles=7.0,
    zero_mean=True,
    use_fft=True,
    decim=1,
    output="complex",
    n_jobs=None,
    *,
    verbose=None,
):
    """Compute Time-Frequency Representation (TFR) using Morlet wavelets.

    Same computation as `~mne.time_frequency.tfr_morlet`, but operates on
    :class:`NumPy arrays <numpy.ndarray>` instead of `~mne.Epochs` objects.

    Parameters
    ----------
    data : array of shape (n_epochs, n_channels, n_times)
        The epochs.
    sfreq : float | int
        Sampling frequency of the data.
    %(freqs_tfr_array)s
    %(n_cycles_tfr)s
    zero_mean : bool | None, default True
        If True, make sure the wavelets have a mean of zero. ``None`` is
        treated as ``False`` for Morlet wavelets.

        .. versionchanged:: 1.8
            The default changed from ``zero_mean=False`` to ``True``.
    use_fft : bool, default True
        Use the FFT for convolutions or not.
    %(decim_tfr)s
    output : str, default ``'complex'``
        * ``'complex'`` : single trial complex.
        * ``'power'`` : single trial power.
        * ``'phase'`` : single trial phase.
        * ``'avg_power'`` : average of single trial power.
        * ``'itc'`` : inter-trial coherence.
        * ``'avg_power_itc'`` : average of single trial power and inter-trial
          coherence across trials.
    %(n_jobs)s
        The parallelization is implemented across channels.
    %(verbose)s

    Returns
    -------
    out : array
        Time frequency transform of ``data``.

        - if ``output in ('complex', 'phase', 'power')``, array of shape
          ``(n_epochs, n_chans, n_freqs, n_times)``
        - else, array of shape ``(n_chans, n_freqs, n_times)``

        If ``output`` is ``'avg_power_itc'``, the real values in ``out``
        contain the average power and the imaginary values contain the ITC:
        :math:`out = power_{avg} + i * itc`.

    See Also
    --------
    mne.time_frequency.tfr_morlet
    mne.time_frequency.tfr_multitaper
    mne.time_frequency.tfr_array_multitaper
    mne.time_frequency.tfr_stockwell
    mne.time_frequency.tfr_array_stockwell

    Notes
    -----
    %(morlet_reference)s
    %(temporal_window_tfr_intro)s
    %(temporal_window_tfr_morlet_notes)s

    .. versionadded:: 0.14.0

    References
    ----------
    .. footbibliography::
    """
    return _compute_tfr(
        epoch_data=data,
        freqs=freqs,
        sfreq=sfreq,
        method="morlet",
        n_cycles=n_cycles,
        zero_mean=zero_mean,
        time_bandwidth=None,
        use_fft=use_fft,
        decim=decim,
        output=output,
        n_jobs=n_jobs,
        verbose=verbose,
    )
@legacy(alt='.compute_tfr(method="multitaper")')
@verbose
def tfr_multitaper(
    inst,
    freqs,
    n_cycles,
    time_bandwidth=4.0,
    use_fft=True,
    return_itc=True,
    decim=1,
    n_jobs=None,
    picks=None,
    average=True,
    *,
    verbose=None,
):
    """Compute Time-Frequency Representation (TFR) using DPSS tapers.

    This performs the same computation as
    :func:`~mne.time_frequency.tfr_array_multitaper`, but takes
    :class:`~mne.Epochs` or :class:`~mne.Evoked` objects as input instead of
    :class:`NumPy arrays <numpy.ndarray>`.

    Parameters
    ----------
    inst : Epochs | Evoked
        The epochs or evoked object to decompose.
    %(freqs_tfr_array)s
    %(n_cycles_tfr)s
    %(time_bandwidth_tfr)s
    use_fft : bool, default True
        Whether to use FFT-based convolution.
    return_itc : bool, default True
        Whether to return the inter-trial coherence (ITC) in addition to the
        averaged (or single-trial) power.
    %(decim_tfr)s
    %(n_jobs)s
    %(picks_good_data)s
    %(average_tfr)s
    %(verbose)s

    Returns
    -------
    power : AverageTFR | EpochsTFR
        The averaged or single-trial power.
    itc : AverageTFR | EpochsTFR
        The inter-trial coherence (ITC). Only returned if ``return_itc``
        is True.

    See Also
    --------
    mne.time_frequency.tfr_array_multitaper
    mne.time_frequency.tfr_stockwell
    mne.time_frequency.tfr_array_stockwell
    mne.time_frequency.tfr_morlet
    mne.time_frequency.tfr_array_morlet

    Notes
    -----
    %(temporal_window_tfr_intro)s
    %(temporal_window_tfr_multitaper_notes)s
    %(time_bandwidth_tfr_notes)s

    .. versionadded:: 0.9.0
    """
    from ..epochs import EpochsArray
    from ..evoked import Evoked

    wants_single_trial_evoked = isinstance(inst, Evoked) and not average
    if wants_single_trial_evoked:
        # Wrap the evoked data as a one-epoch EpochsArray so the result is an
        # EpochsTFR rather than an AverageTFR (backwards compatibility)
        inst = EpochsArray(inst.data[np.newaxis], inst.info, tmin=inst.tmin, proj=False)
    return _tfr_aux(
        method="multitaper",
        inst=inst,
        freqs=freqs,
        decim=decim,
        return_itc=return_itc,
        picks=picks,
        average=average,
        output="power",
        n_cycles=n_cycles,
        n_jobs=n_jobs,
        use_fft=use_fft,
        zero_mean=True,
        time_bandwidth=time_bandwidth,
    )
# TFR(s) class
@fill_doc
class BaseTFR(ContainsMixin, UpdateChannelsMixin, SizeMixin, ExtendedTimeMixin):
"""Base class for RawTFR, EpochsTFR, and AverageTFR (for type checking only).
.. note::
This class should not be instantiated directly; it is provided in the public API
only for type-checking purposes (e.g., ``isinstance(my_obj, BaseTFR)``). To
create TFR objects, use the ``.compute_tfr()`` methods on :class:`~mne.io.Raw`,
:class:`~mne.Epochs`, or :class:`~mne.Evoked`, or use the constructors listed
below under "See Also".
Parameters
----------
inst : instance of Raw, Epochs, or Evoked
The data from which to compute the time-frequency representation.
%(method_tfr)s
%(freqs_tfr)s
%(tmin_tmax_psd)s
%(picks_good_data_noref)s
%(proj_psd)s
%(decim_tfr)s
%(n_jobs)s
%(reject_by_annotation_tfr)s
%(verbose)s
%(method_kw_tfr)s
See Also
--------
mne.time_frequency.RawTFR
mne.time_frequency.RawTFRArray
mne.time_frequency.EpochsTFR
mne.time_frequency.EpochsTFRArray
mne.time_frequency.AverageTFR
mne.time_frequency.AverageTFRArray
"""
def __init__(
    self,
    inst,
    method,
    freqs,
    tmin,
    tmax,
    picks,
    proj,
    *,
    decim,
    n_jobs,
    reject_by_annotation=None,
    verbose=None,
    **method_kw,
):
    """Compute the TFR of ``inst``, or restore a TFR from a state dict.

    Parameters are documented in the class docstring. If ``inst`` is a dict, it
    is treated as serialized state and handed to ``__setstate__``; nothing is
    computed in that case.

    Raises
    ------
    ValueError
        If ``method`` or ``freqs`` is None. ``method`` is additionally validated
        with ``_check_option`` against the methods allowed for the type of
        ``inst`` ("stockwell" is only offered for epochs).
    """
    from ..epochs import BaseEpochs
    from ._stockwell import tfr_array_stockwell

    # triage reading from file: a dict is serialized state, not data to analyze
    if isinstance(inst, dict):
        self.__setstate__(inst)
        return
    if method is None or freqs is None:
        problem = [
            f"{k}=None"
            for k, v in dict(method=method, freqs=freqs).items()
            if v is None
        ]
        # Guess the public class name for the error message from the local
        # variable names of the *calling* frame (the child-class constructor).
        # TODO when py3.11 is min version, replace if/elif/else block with
        # classname = inspect.currentframe().f_back.f_code.co_qualname.split(".")[0]
        _varnames = inspect.currentframe().f_back.f_code.co_varnames
        if "BaseRaw" in _varnames:
            classname = "RawTFR"
        elif "Evoked" in _varnames:
            classname = "AverageTFR"
        else:
            assert "BaseEpochs" in _varnames and "Evoked" not in _varnames
            classname = "EpochsTFR"
        # end TODO
        raise ValueError(
            f'{classname} got unsupported parameter value{_pl(problem)} '
            f'{" and ".join(problem)}.'
        )
    # shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release)
    if method == "morlet":
        method_kw.setdefault("zero_mean", True)
    # check method; stockwell is only offered for epoched input
    valid_methods = ["morlet", "multitaper"]
    if isinstance(inst, BaseEpochs):
        valid_methods.append("stockwell")
    method = _check_option("method", method, valid_methods)
    # for stockwell, `tmin, tmax` already added to `method_kw` by calling method,
    # and `freqs` vector has been pre-computed
    if method != "stockwell":
        method_kw.update(freqs=freqs)
        # ↓↓↓ if constructor called directly, prevents key error
        method_kw.setdefault("output", "power")
    self._freqs = np.asarray(freqs, dtype=np.float64)
    del freqs
    # check validity of kwargs manually to save compute time if any are invalid
    tfr_funcs = dict(
        morlet=tfr_array_morlet,
        multitaper=tfr_array_multitaper,
        stockwell=tfr_array_stockwell,
    )
    _check_method_kwargs(tfr_funcs[method], method_kw, msg=f'TFR method "{method}"')
    # bind the method-specific kwargs now; _compute_tfr supplies data/sfreq later
    self._tfr_func = partial(tfr_funcs[method], **method_kw)
    # apply proj if desired (on a copy, so the caller's instance is untouched)
    if proj:
        inst = inst.copy().apply_proj()
    self.inst = inst
    # prep picks and add the info object. bads and non-data channels are dropped by
    # _picks_to_idx() so we update the info accordingly:
    self._picks = _picks_to_idx(inst.info, picks, "data", with_ref_meg=False)
    self.info = pick_info(inst.info, sel=self._picks, copy=True)
    # assign some attributes
    self._method = method
    self._inst_type = type(inst)
    self._baseline = None
    self.preload = True  # needed for __getitem__, never False for TFRs
    # self._dims may also get updated by child classes
    self._dims = ["channel", "freq", "time"]
    # multitaper complex/phase output keeps a separate axis per taper
    self._needs_taper_dim = method == "multitaper" and method_kw["output"] in (
        "complex",
        "phase",
    )
    if self._needs_taper_dim:
        self._dims.insert(1, "taper")
    self._dims = tuple(self._dims)
    # get the instance data.
    time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq)
    get_instance_data_kw = dict(time_mask=time_mask)
    if reject_by_annotation is not None:
        get_instance_data_kw.update(reject_by_annotation=reject_by_annotation)
    data = self._get_instance_data(**get_instance_data_kw)
    # compute the TFR
    self._decim = _ensure_slice(decim)
    # undecimated, time-masked times; _compute_tfr uses them for the expected shape
    self._raw_times = inst.times[time_mask]
    self._compute_tfr(data, n_jobs, verbose)
    self._update_epoch_attributes()
    # "apply" decim to the rest of the object (data is decimated in _compute_tfr)
    with self.info._unlock():
        self.info["sfreq"] /= self._decim.step
    # NOTE(review): `_decim_times` decimates the *full* `inst.times` starting at
    # index 0, whereas the data were decimated after `time_mask` was applied. If
    # the first masked sample is not a multiple of the decimation step, these two
    # grids may be offset from each other — confirm.
    _decim_times = inst.times[self._decim]
    _decim_time_mask = _time_mask(_decim_times, tmin, tmax, sfreq=self.sfreq)
    self._raw_times = _decim_times[_decim_time_mask].copy()
    self._set_times(self._raw_times)
    self._decim = 1
    # record data type (for repr and html_repr). ITC handled in the calling method.
    if method == "stockwell":
        self._data_type = "Power Estimates"
    else:
        data_types = dict(
            power="Power Estimates",
            avg_power="Average Power Estimates",
            avg_power_itc="Average Power Estimates",
            phase="Phase",
            complex="Complex Amplitude",
        )
        self._data_type = data_types[method_kw["output"]]
    # check for correct shape and bad values. `tfr_array_stockwell` doesn't take kw
    # `output` so it may be missing here, so use `.get()`
    negative_ok = method_kw.get("output", "") in ("complex", "phase")
    self._check_values(negative_ok=negative_ok)
    # we don't need these anymore, and they make save/load harder
    del self._picks
    del self._tfr_func
    del self._needs_taper_dim
    del self._shape  # calculated from self._data henceforth
    del self.inst  # save memory
def __abs__(self):
    """Return a new TFR whose data is the element-wise magnitude of this one's.

    The original instance is left unchanged.
    """
    magnitude = np.absolute(self.data)
    out = self.copy()
    out.data = magnitude
    return out
@fill_doc
def __add__(self, other):
    """Add two TFR instances.

    %(__add__tfr)s
    """
    # must be called directly from here: it derives the operation name ("add")
    # from this method's frame
    self._check_compatibility(other)
    total = self.copy()
    total.data += other.data
    return total
@fill_doc
def __iadd__(self, other):
    """Add a TFR instance to another, in-place.

    %(__iadd__tfr)s
    """
    # direct call required: the helper reads "__iadd__" from the calling frame
    self._check_compatibility(other)
    self.data += other.data
    return self
@fill_doc
def __sub__(self, other):
    """Subtract two TFR instances.

    %(__sub__tfr)s
    """
    # direct call required: the helper reads "__sub__" from the calling frame
    self._check_compatibility(other)
    difference = self.copy()
    difference.data -= other.data
    return difference
@fill_doc
def __isub__(self, other):
    """Subtract a TFR instance from another, in-place.

    %(__isub__tfr)s
    """
    # direct call required: the helper reads "__isub__" from the calling frame
    self._check_compatibility(other)
    self.data -= other.data
    return self
@fill_doc
def __mul__(self, num):
    """Multiply a TFR instance by a scalar.

    %(__mul__tfr)s
    """
    scaled = self.copy()
    scaled.data *= num
    return scaled
@fill_doc
def __imul__(self, num):
    """Multiply a TFR instance by a scalar, in-place.

    %(__imul__tfr)s
    """
    self.data *= num  # scales the existing array; no copy is made
    return self
@fill_doc
def __truediv__(self, num):
    """Divide a TFR instance by a scalar.

    %(__truediv__tfr)s
    """
    quotient = self.copy()
    quotient.data /= num
    return quotient
@fill_doc
def __itruediv__(self, num):
    """Divide a TFR instance by a scalar, in-place.

    %(__itruediv__tfr)s
    """
    self.data /= num  # divides the existing array; no copy is made
    return self
def __eq__(self, other):
    """Test equivalence of two TFR instances.

    Two objects compare equal when ``object_diff`` reports no difference
    between their instance-attribute dicts. Objects that have no ``__dict__``
    (e.g. ``None`` or plain numbers) cannot be compared this way. For those,
    ``NotImplemented`` is returned, so Python falls back to its default
    identity comparison instead of raising ``TypeError``.
    """
    try:
        other_attrs = vars(other)
    except TypeError:  # no __dict__ -> not a comparable object
        return NotImplemented
    return object_diff(vars(self), other_attrs) == ""
def __getstate__(self):
    """Collect everything needed to rebuild this TFR into a plain dict.

    The keys match those read back by ``__setstate__``.
    """
    state = dict(method=self.method, data=self._data)
    state.update(
        sfreq=self.sfreq,
        dims=self._dims,
        freqs=self.freqs,
        times=self.times,
        inst_type_str=_get_instance_type_string(self),
        data_type=self._data_type,
        info=self.info,
        baseline=self._baseline,
        decim=self._decim,
    )
    return state
def __setstate__(self, state):
    """Rebuild the TFR from a state dict (as produced by ``__getstate__``).

    Keys missing from ``state`` fall back to defaults. ``dims`` defaults to the
    trailing ``state["data"].ndim`` entries of
    ``("epoch", "channel", "freq", "time")``. After unpacking, the agreement
    of data, freqs, times, and info is verified with ``_check_state``.
    """
    from ..epochs import Epochs
    from ..evoked import Evoked
    from ..io import Raw

    n_dim = state["data"].ndim
    merged = dict(
        method="unknown",
        dims=("epoch", "channel", "freq", "time")[-n_dim:],
        baseline=None,
        decim=1,
        data_type="TFR",
        inst_type_str="Unknown",
    )
    merged.update(**state)
    # unpack (same order as the attributes are needed downstream)
    self._method = merged["method"]
    self._data = merged["data"]
    self._freqs = np.asarray(merged["freqs"], dtype=np.float64)
    self._dims = merged["dims"]
    self._raw_times = np.asarray(merged["times"], dtype=np.float64)
    self._baseline = merged["baseline"]
    self.info = Info(**merged["info"])
    self._data_type = merged["data_type"]
    self._decim = merged["decim"]
    self.preload = True
    self._set_times(self._raw_times)
    # Before gh-11282 Raw was not a possible source, so a state lacking
    # `inst_type_str` came from Epochs (has an "epoch" axis) or Evoked.
    inst_types = {
        "Raw": Raw,
        "Epochs": Epochs,
        "Evoked": Evoked,
        "Unknown": Epochs if "epoch" in self._dims else Evoked,
    }
    self._inst_type = inst_types[merged["inst_type_str"]]
    # sanity check data/freqs/times/info agreement
    self._check_state()
def __repr__(self):
    """Return a one-line summary: data type, source, method, shape, ranges, size."""
    source = _get_instance_type_string(self)
    nave_str = f" (nave={self.nave})" if hasattr(self, "nave") else ""
    dim_parts = [f"{count} {name}s" for count, name in zip(self.shape, self._dims)]
    shape_str = " × ".join(dim_parts)
    fr = self.freqs
    tm = self.times
    return (
        f"<{self._data_type} from {source}{nave_str}, "
        f"{self.method} method | {shape_str}, "
        f"{fr[0]:0.1f} - {fr[-1]:0.1f} Hz, {tm[0]:0.2f} - {tm[-1]:0.2f} s, "
        f"{sizeof_fmt(self._size)}>"
    )
@repr_html
def _repr_html_(self, caption=None):
    """Render the TFR as HTML using the ``repr/tfr.html.jinja`` template."""
    from ..html_templates import _get_html_template

    template = _get_html_template("repr", "tfr.html.jinja")
    return template.render(
        tfr=self,
        inst_type=_get_instance_type_string(self),
        nave=getattr(self, "nave", 0),  # only averaged TFRs carry `nave`
        caption=caption,
    )
def _check_compatibility(self, other):
    """Check compatibility of two TFR instances, in preparation for arithmetic.

    The operation name in the error message is taken from the *calling*
    function's name (e.g. ``__iadd__`` -> "add"). This method must therefore
    be called directly from the arithmetic dunder method.

    Raises
    ------
    RuntimeError
        If ``other`` is not of the same type as ``self``, or if the times,
        frequencies, or channel names of the two instances differ.
    """
    operation = inspect.currentframe().f_back.f_code.co_name.strip("_")
    if operation.startswith("i"):  # in-place variants: "iadd" -> "add"
        operation = operation[1:]
    msg = f"Cannot {operation} the two TFR instances: {{}} do not match{{}}."
    extra = ""
    if not isinstance(other, type(self)):
        problem = "types"
        extra = f" (self is {type(self)}, other is {type(other)})"
    elif not np.array_equal(self.times, other.times):
        problem = "times"
    elif not np.array_equal(self.freqs, other.freqs):
        problem = "freqs"
    elif list(self.ch_names) != list(other.ch_names):
        # element-wise arithmetic pairs rows by position, so differing (or
        # differently ordered) channels would silently mix unrelated sensors
        problem = "channel names"
    else:  # should be OK
        return
    raise RuntimeError(msg.format(problem, extra))
def _check_state(self):
    """Verify that the data axes agree with ``info``, ``freqs``, and ``times``.

    Used by ``__setstate__``. Axes are checked in the order channel, frequency,
    time, and a ``ValueError`` names the first inconsistent one.
    """
    template = "{} axis of data ({}) doesn't match {} attribute ({})"
    shape = self._data.shape
    checks = (
        ("Channel", shape[self._dims.index("channel")], "info", len(self.info["chs"])),
        ("Frequency", shape[self._dims.index("freq")], "freqs", len(self.freqs)),
        ("Time", shape[self._dims.index("time")], "times", len(self.times)),
    )
    for axis_label, n_data, attr_label, n_attr in checks:
        if n_data != n_attr:
            raise ValueError(template.format(axis_label, n_data, attr_label, n_attr))
def _check_values(self, negative_ok=False):
    """Check TFR results for correct shape and implausible negative values.

    Parameters
    ----------
    negative_ok : bool
        If True (complex or phase output), negative values are expected. In
        that case only the shape is checked and the negativity scan is skipped.
    """
    assert len(self._dims) == self._data.ndim
    assert self._data.shape == self._shape
    if negative_ok:
        # The per-channel minimum would be computed only to be discarded;
        # skip the full-array reduction.
        return
    # Check for implausible power values: take min() across all but the channel axis
    # TODO: should this be more fine-grained (report "chan X in epoch Y")?
    ch_dim = self._dims.index("channel")
    reduce_axes = tuple(ax for ax in range(self._data.ndim) if ax != ch_dim)
    negative_values = self._data.min(axis=reduce_axes) < 0
    if negative_values.any():
        chs = np.array(self.ch_names)[negative_values].tolist()
        s = _pl(negative_values.sum())
        warn(
            f"Negative value in time-frequency decomposition for channel{s} "
            f'{", ".join(chs)}',
            UserWarning,
        )
def _compute_tfr(self, data, n_jobs, verbose):
    """Run the bound TFR function on ``data`` and store the results.

    Sets ``self._data`` (and ``self._itc`` for stockwell or ``*_itc`` outputs).
    It also records the *expected* output shape in ``self._shape``, which
    ``_check_values`` asserts against later.

    Parameters
    ----------
    data : array
        Instance data returned by ``_get_instance_data`` (already time-masked).
    n_jobs : int | None
        Forwarded to the TFR function.
    verbose : bool | str | int | None
        Forwarded to the TFR function.
    """
    result = self._tfr_func(
        data,
        self.sfreq,
        decim=self._decim,
        n_jobs=n_jobs,
        verbose=verbose,
    )
    # assign ._data and maybe ._itc
    # tfr_array_stockwell always returns ITC (sometimes it's None)
    if self.method == "stockwell":
        self._data, self._itc, freqs = result
        # stockwell freqs were pre-computed by the caller; they must agree
        assert np.array_equal(self._freqs, freqs)
    elif self._tfr_func.keywords.get("output", "").endswith("_itc"):
        # "*_itc" outputs carry power in the real part and ITC in the imaginary part
        self._data, self._itc = result.real, result.imag
    else:
        self._data = result
    # remove fake "epoch" dimension: for non-Epochs input the result has a
    # leading singleton axis (stockwell is only allowed for epochs)
    if self.method != "stockwell" and _get_instance_type_string(self) != "Epochs":
        self._data = np.squeeze(self._data, axis=0)
    # this is *expected* shape, it gets asserted later in _check_values()
    # (and then deleted afterwards)
    expected_shape = [
        len(self.ch_names),
        len(self.freqs),
        len(self._raw_times[self._decim]),  # don't use self.times, not set yet
    ]
    # deal with the "taper" dimension (its size is taken from the result itself)
    if self._needs_taper_dim:
        expected_shape.insert(1, self._data.shape[1])
    self._shape = tuple(expected_shape)
@verbose
def _onselect(
    self,
    eclick,
    erelease,
    picks=None,
    exclude="bads",
    combine="mean",
    baseline=None,
    mode=None,
    cmap=None,
    source_plot_joint=False,
    topomap_args=None,
    verbose=None,
):
    """Respond to rectangle selector in TFR image plots with a topomap plot.

    The selected rectangle is snapped to the nearest available time/frequency
    points. A new figure then shows the topography averaged over that window:
    a single topomap when ``source_plot_joint`` is True, otherwise one topomap
    per available channel type (eeg, mag, grad).

    Parameters
    ----------
    eclick, erelease : event
        Press/release events of the selection. Their ``x``/``y`` and
        ``xdata``/``ydata`` attributes are read; presumably matplotlib mouse
        events from a rectangle selector (TODO confirm).
    picks, exclude, combine : see ``get_data`` / plotting methods
        Only used when ``source_plot_joint`` is True.
    baseline, mode : see ``apply_baseline``
        Baseline correction applied before averaging.
    cmap : colormap spec | None
        Used as the default ``cmap`` in ``topomap_args``.
    source_plot_joint : bool
        Whether the selection came from an image drawn by ``plot_joint``.
    topomap_args : dict | None
        Extra keyword arguments for the topomap plotting call. Note that
        ``cmap`` and ``vlim`` are set on it via ``setdefault``, so a dict that
        is passed in is modified.
    %(verbose)s
    """
    # ignore degenerate selections (near-zero width or height, e.g. a click)
    if abs(eclick.x - erelease.x) < 0.1 or abs(eclick.y - erelease.y) < 0.1:
        return
    t_range = (min(eclick.xdata, erelease.xdata), max(eclick.xdata, erelease.xdata))
    f_range = (min(eclick.ydata, erelease.ydata), max(eclick.ydata, erelease.ydata))
    # snap to nearest measurement point
    t_idx = np.abs(self.times - np.atleast_2d(t_range).T).argmin(axis=1)
    f_idx = np.abs(self.freqs - np.atleast_2d(f_range).T).argmin(axis=1)
    tmin, tmax = self.times[t_idx]
    fmin, fmax = self.freqs[f_idx]
    # immutable → mutable default
    if topomap_args is None:
        topomap_args = dict()
    topomap_args.setdefault("cmap", cmap)
    topomap_args.setdefault("vlim", (None, None))
    # figure out which channel types we're dealing with
    types = list()
    if "eeg" in self:
        types.append("eeg")
    if "mag" in self:
        types.append("mag")
    if "grad" in self:
        grad_picks = _pair_grad_sensors(
            self.info, topomap_coords=False, raise_error=False
        )
        # NOTE(review): if `grad_picks` lists individual channels (two per pair),
        # `> 1` means at least one pair, not the two pairs the message mentions.
        if len(grad_picks) > 1:
            types.append("grad")
        elif len(types) == 0:
            logger.info(
                "Need at least 2 gradiometer pairs to plot a gradiometer topomap."
            )
            return  # Don't draw a figure for nothing.
    fig = figure_nobar()
    t_range = f"{tmin:.3f}" if tmin == tmax else f"{tmin:.3f} - {tmax:.3f}"
    f_range = f"{fmin:.2f}" if fmin == fmax else f"{fmin:.2f} - {fmax:.2f}"
    fig.suptitle(f"{t_range} s,\n{f_range} Hz")
    if source_plot_joint:
        ax = fig.add_subplot()
        data, times, freqs = self.get_data(
            picks=picks, exclude=exclude, return_times=True, return_freqs=True
        )
        # merge grads before baselining (makes ERDs visible)
        ch_types = np.array(self.get_channel_types(unique=True))
        ch_type = ch_types.item()  # will error if there are more than one
        data, pos = _merge_if_grads(
            data=data,
            info=self.info,
            ch_type=ch_type,
            sphere=topomap_args.get("sphere"),
            combine=combine,
        )
        # baseline and crop
        data, *_ = _prep_data_for_plot(
            data,
            times,
            freqs,
            tmin=tmin,
            tmax=tmax,
            fmin=fmin,
            fmax=fmax,
            baseline=baseline,
            mode=mode,
            verbose=verbose,
        )
        # average over times and freqs
        data = data.mean((-2, -1))
        im, _ = plot_topomap(data, pos, axes=ax, show=False, **topomap_args)
        _add_colorbar(ax, im, topomap_args["cmap"], title="AU")
        plt_show(fig=fig)
    else:
        # one subplot per channel type, side by side
        for idx, ch_type in enumerate(types):
            ax = fig.add_subplot(1, len(types), idx + 1)
            plot_tfr_topomap(
                self,
                ch_type=ch_type,
                tmin=tmin,
                tmax=tmax,
                fmin=fmin,
                fmax=fmax,
                baseline=baseline,
                mode=mode,
                axes=ax,
                **topomap_args,
            )
            ax.set_title(ch_type)
def _update_epoch_attributes(self):
    """Hook for subclasses; does nothing in the base class.

    EpochsTFR overrides it to add the attributes that ``to_data_frame`` and
    ``__getitem__`` rely on.
    """
@property
def _detrend_picks(self):
    """Empty list of picks to detrend (TFRs are never detrended); for ``__iter__``."""
    return []
@property
def baseline(self):
    """Baseline period (start, end) in seconds, or None if never baselined."""
    return self._baseline
@property
def ch_names(self):
    """Names of the channels, as stored in ``info["ch_names"]``."""
    info = self.info
    return info["ch_names"]
@property
def data(self):
    """The TFR data array (power, phase, or complex amplitude, per the output type)."""
    return self._data

@data.setter
def data(self, new_data):
    self._data = new_data
@property
def freqs(self):
    """Frequencies (in Hz) at which the time-frequency estimates were computed."""
    return self._freqs
@property
def method(self):
    """Name of the TFR method used (e.g. "morlet", "multitaper", "stockwell")."""
    return self._method
@property
def sfreq(self):
    """Sampling frequency of the data, read from ``info["sfreq"]``."""
    info = self.info
    return info["sfreq"]
@property
def shape(self):
    """Shape of the underlying data array."""
    arr = self._data
    return arr.shape
@property
def times(self):
    """Time points (in seconds) of the data's time axis."""
    return self._times_readonly
@fill_doc
def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True):
    """Crop data to a given time interval and/or frequency range, in place.

    Parameters
    ----------
    %(tmin_tmax_psd)s
    fmin : float | None
        Lowest frequency of selection in Hz.

        .. versionadded:: 0.18.0
    fmax : float | None
        Highest frequency of selection in Hz.

        .. versionadded:: 0.18.0
    %(include_tmax)s

    Returns
    -------
    %(inst_tfr)s
        The modified instance.
    """
    # the time axis is handled by the parent mixin's crop()
    super().crop(tmin=tmin, tmax=tmax, include_tmax=include_tmax)
    if fmin is None and fmax is None:
        freq_idx = slice(None)  # keep every frequency
    else:
        keep = _freq_mask(self.freqs, sfreq=self.info["sfreq"], fmin=fmin, fmax=fmax)
        # integer positions select along the frequency axis (second to last)
        freq_idx = np.flatnonzero(keep)
    self._freqs = self.freqs[freq_idx]
    self._data = self._data[..., freq_idx, :]
    return self
@fill_doc
def copy(self):
    """Return copy of the TFR instance.

    Returns
    -------
    %(inst_tfr)s
        A copy of the object.
    """
    # deep copy: the returned object shares no mutable state (arrays, info)
    # with the original
    return deepcopy(self)
@verbose
def apply_baseline(self, baseline, mode="mean", verbose=None):
    """Baseline correct the data, in place.

    Parameters
    ----------
    %(baseline_rescale)s

        How baseline is computed is determined by the ``mode`` parameter.
    mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio'
        Perform baseline correction by

        - subtracting the mean of baseline values ('mean')
        - dividing by the mean of baseline values ('ratio')
        - dividing by the mean of baseline values and taking the log
          ('logratio')
        - subtracting the mean of baseline values followed by dividing by
          the mean of baseline values ('percent')
        - subtracting the mean of baseline values and dividing by the
          standard deviation of baseline values ('zscore')
        - dividing by the mean of baseline values, taking the log, and
          dividing by the standard deviation of log baseline values
          ('zlogratio')
    %(verbose)s

    Returns
    -------
    %(inst_tfr)s
        The modified instance.
    """
    checked = _check_baseline(baseline, times=self.times, sfreq=self.sfreq)
    self._baseline = checked
    # copy=False: the data array is modified in place
    rescale(self.data, self.times, checked, mode, copy=False, verbose=verbose)
    return self
@fill_doc
def get_data(
    self,
    picks=None,
    exclude="bads",
    fmin=None,
    fmax=None,
    tmin=None,
    tmax=None,
    return_times=False,
    return_freqs=False,
):
    """Get time-frequency data in NumPy array format.

    Parameters
    ----------
    %(picks_good_data_noref)s
    %(exclude_spectrum_get_data)s
    %(fmin_fmax_tfr)s
    %(tmin_tmax_psd)s
    return_times : bool
        Whether to return the time values for the requested time range.
        Default is ``False``.
    return_freqs : bool
        Whether to return the frequency bin values for the requested
        frequency range. Default is ``False``.

    Returns
    -------
    data : array
        The requested data in a NumPy array.
    times : array
        The time values for the requested data range. Only returned if
        ``return_times`` is ``True``.
    freqs : array
        The frequency values for the requested data range. Only returned if
        ``return_freqs`` is ``True``.

    Notes
    -----
    Returns a copy of the underlying data (not a view). The returned
    ``times`` and ``freqs`` are copies too, so modifying them does not alter
    the TFR object.
    """
    tmin = self.times[0] if tmin is None else tmin
    tmax = self.times[-1] if tmax is None else tmax
    fmin = 0 if fmin is None else fmin
    fmax = np.inf if fmax is None else fmax
    picks = _picks_to_idx(
        self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False
    )
    # inclusive [fmin, fmax] / [tmin, tmax] windows as index ranges
    fmin_idx = np.searchsorted(self.freqs, fmin)
    fmax_idx = np.searchsorted(self.freqs, fmax, side="right")
    tmin_idx = np.searchsorted(self.times, tmin)
    tmax_idx = np.searchsorted(self.times, tmax, side="right")
    freq_picks = np.arange(fmin_idx, fmax_idx)
    time_picks = np.arange(tmin_idx, tmax_idx)
    freq_axis = self._dims.index("freq")
    time_axis = self._dims.index("time")
    chan_axis = self._dims.index("channel")
    # normally there's a risk of np.take reducing array dimension if there
    # were only one channel or frequency selected, but `_picks_to_idx`
    # and np.arange both always return arrays, so we're safe; the result
    # will always have the same `ndim` as it started with.
    data = (
        self._data.take(picks, chan_axis)
        .take(freq_picks, freq_axis)
        .take(time_picks, time_axis)
    )
    out = [data]
    if return_times:
        # slice the same array the indices were computed from, and copy it so
        # callers cannot mutate the object's internal time vector
        out.append(self.times[tmin_idx:tmax_idx].copy())
    if return_freqs:
        out.append(self._freqs[fmin_idx:fmax_idx].copy())
    if not return_times and not return_freqs:
        return out[0]
    return tuple(out)
@verbose
def plot(
    self,
    picks=None,
    *,
    exclude=(),
    tmin=None,
    tmax=None,
    fmin=0.0,
    fmax=np.inf,
    baseline=None,
    mode="mean",
    dB=False,
    combine=None,
    layout=None,  # TODO deprecate? not used in orig implementation either
    yscale="auto",
    vmin=None,
    vmax=None,
    vlim=(None, None),
    cnorm=None,
    cmap=None,
    colorbar=True,
    title=None,  # don't deprecate this one; has (useful) option title="auto"
    mask=None,
    mask_style=None,
    mask_cmap="Greys",
    mask_alpha=0.1,
    axes=None,
    show=True,
    verbose=None,
):
    """Plot TFRs as two-dimensional time-frequency images.

    Parameters
    ----------
    %(picks_good_data)s
    %(exclude_spectrum_plot)s
    %(tmin_tmax_psd)s
    %(fmin_fmax_tfr)s
    %(baseline_rescale)s

        How baseline is computed is determined by the ``mode`` parameter.
    %(mode_tfr_plot)s
    %(dB_spectrum_plot)s
    %(combine_tfr_plot)s

        .. versionchanged:: 1.3
           Added support for ``callable``.
    %(layout_spectrum_plot_topo)s
    %(yscale_tfr_plot)s

        .. versionadded:: 0.14.0
    %(vmin_vmax_tfr_plot)s
    %(vlim_tfr_plot)s
    %(cnorm)s

        .. versionadded:: 0.24
    %(cmap_topomap)s
    %(colorbar)s
    %(title_tfr_plot)s
    %(mask_tfr_plot)s

        .. versionadded:: 0.16.0
    %(mask_style_tfr_plot)s

        .. versionadded:: 0.17
    %(mask_cmap_tfr_plot)s

        .. versionadded:: 0.17
    %(mask_alpha_tfr_plot)s

        .. versionadded:: 0.16.0
    %(axes_tfr_plot)s
    %(show)s
    %(verbose)s

    Returns
    -------
    figs : list of instances of matplotlib.figure.Figure
        A list of figures containing the time-frequency power.

    Notes
    -----
    If ``axes`` is a NumPy array (e.g. a grid from ``plt.subplots``), it is
    flattened in row-major order, and images are drawn into its Axes in that
    order.
    """
    # deprecations
    vlim = _warn_deprecated_vmin_vmax(vlim, vmin, vmax)
    # the rectangle selector plots topomaps, which needs all channels uncombined,
    # so we keep a reference to that state here, and (because the topomap plotting
    # function wants an AverageTFR) update it with `comment` and `nave` values in
    # case we started out with a singleton EpochsTFR or RawTFR
    initial_state = self.__getstate__()
    initial_state.setdefault("comment", "")
    initial_state.setdefault("nave", 1)
    # `_picks_to_idx` also gets done inside `get_data()`` below, but we do it here
    # because we need the indices later
    idx_picks = _picks_to_idx(
        self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False
    )
    pick_names = np.array(self.ch_names)[idx_picks].tolist()  # for titles
    ch_types = self.get_channel_types(idx_picks)
    # get data arrays
    data, times, freqs = self.get_data(
        picks=idx_picks, exclude=(), return_times=True, return_freqs=True
    )
    # pass tmin/tmax here ↓↓↓, not here ↑↑↑; we want to crop *after* baselining
    data, times, freqs = _prep_data_for_plot(
        data,
        times,
        freqs,
        tmin=tmin,
        tmax=tmax,
        fmin=fmin,
        fmax=fmax,
        baseline=baseline,
        mode=mode,
        dB=dB,
        verbose=verbose,
    )
    # shape
    ch_axis = self._dims.index("channel")
    freq_axis = self._dims.index("freq")
    time_axis = self._dims.index("time")
    want_shape = list(self.shape)
    want_shape[ch_axis] = len(idx_picks) if combine is None else 1
    want_shape[freq_axis] = len(freqs)  # in case there was fmin/fmax cropping
    want_shape[time_axis] = len(times)  # in case there was tmin/tmax cropping
    want_shape = tuple(want_shape)
    # combine
    combine_was_none = combine is None
    combine = _make_combine_callable(
        combine, axis=ch_axis, valid=("mean", "rms"), keepdims=True
    )
    try:
        data = combine(data)  # no need to copy; get_data() never returns a view
    except Exception as e:
        msg = (
            "Something went wrong with the callable passed to 'combine'; see "
            "traceback."
        )
        raise ValueError(msg) from e
    # call succeeded, check type and shape
    mismatch = False
    extra = ""
    if not isinstance(data, np.ndarray):
        mismatch = "type"
    elif data.shape not in (want_shape, want_shape[1:]):
        mismatch = "shape"
        extra = f" of shape {data.shape}"
    if mismatch:
        raise RuntimeError(
            f"Wrong {mismatch} yielded by callable passed to 'combine'. Make sure "
            "your function takes a single argument (an array of shape "
            "(n_channels, n_freqs, n_times)) and returns an array of shape "
            f"(n_freqs, n_times); yours yielded: {type(data)}{extra}."
        )
    # restore singleton collapsed axis (removed by user-provided callable):
    # (n_freqs, n_times) → (1, n_freqs, n_times)
    if data.shape == (len(freqs), len(times)):
        data = data[np.newaxis]
    assert data.shape == want_shape
    # cmap handling. Power may be negative depending on the baseline strategy.
    # As written, `norm` is inferred from the data only when the user passed
    # explicit limits; with the default vlim=(None, None), norm=False is used.
    # NOTE(review): a previous comment described the opposite condition —
    # confirm which behavior is intended.
    norm = False if vlim == (None, None) else data.min() >= 0.0
    vmin, vmax = _setup_vmin_vmax(data, *vlim, norm=norm)
    cmap = _setup_cmap(cmap, norm=norm)
    # prepare figure(s)
    if axes is None:
        figs = [plt.figure(layout="constrained") for _ in range(data.shape[0])]
        axes = [fig.add_subplot() for fig in figs]
    elif isinstance(axes, plt.Axes):
        figs = [axes.get_figure()]
        axes = [axes]
    elif isinstance(axes, np.ndarray):  # allow plotting into a grid of axes
        # flatten (row-major) so that `len(axes)` counts every Axes and
        # `axes[ix]` below yields an Axes rather than a whole row of the grid
        axes = list(axes.flat)
        figs = [ax.get_figure() for ax in axes]
    elif hasattr(axes, "__iter__") and len(axes):
        figs = [ax.get_figure() for ax in axes]
    else:
        raise ValueError(
            f"axes must be None, Axes, or list/array of Axes, got {type(axes)}"
        )
    if len(axes) != data.shape[0]:
        raise RuntimeError(
            f"Mismatch between picked channels ({data.shape[0]}) and axes "
            f"({len(axes)}); there must be one axes for each picked channel."
        )
    # check if we're being called from within plot_joint(). If so, get the
    # `topomap_args` from the calling context and pass it to the onselect handler.
    # (we need 2 `f_back` here because of the verbose decorator)
    calling_frame = inspect.currentframe().f_back.f_back
    source_plot_joint = calling_frame.f_code.co_name == "plot_joint"
    topomap_args = (
        dict()
        if not source_plot_joint
        else calling_frame.f_locals.get("topomap_args", dict())
    )
    # plot
    for ix, _fig in enumerate(figs):
        # restrict the onselect instance to the channel type of the picks used in
        # the image plot
        uniq_types = np.unique(ch_types)
        ch_type = None if len(uniq_types) > 1 else uniq_types.item()
        this_tfr = AverageTFR(inst=initial_state).pick(ch_type, verbose=verbose)
        _onselect_callback = partial(
            this_tfr._onselect,
            picks=None,  # already restricted the picks in `this_tfr`
            exclude=(),
            baseline=baseline,
            mode=mode,
            cmap=cmap,
            source_plot_joint=source_plot_joint,
            topomap_args=topomap_args,
        )
        # draw the image plot
        _imshow_tfr(
            ax=axes[ix],
            tfr=data[[ix]],
            ch_idx=0,
            tmin=times[0],
            tmax=times[-1],
            vmin=vmin,
            vmax=vmax,
            onselect=_onselect_callback,
            ylim=None,
            freq=freqs,
            x_label="Time (s)",
            y_label="Frequency (Hz)",
            colorbar=colorbar,
            cmap=cmap,
            yscale=yscale,
            mask=mask,
            mask_style=mask_style,
            mask_cmap=mask_cmap,
            mask_alpha=mask_alpha,
            cnorm=cnorm,
        )
        # handle title. automatic title is:
        # f"{Baselined} {power} ({ch_name})" or
        # f"{Baselined} {power} ({combination} of {N} {ch_type}s)"
        if title == "auto":
            if combine_was_none:  # one plot per channel
                which_chs = pick_names[ix]
            elif len(pick_names) == 1:  # there was only one pick anyway
                which_chs = pick_names[0]
            else:  # one plot for all chs combined
                which_chs = _set_title_multiple_electrodes(
                    None, combine, pick_names, all_=True, ch_type=ch_type
                )
            _prefix = "Power" if baseline is None else "Baselined power"
            _title = f"{_prefix} ({which_chs})"
        else:
            _title = title
        _fig.suptitle(_title)
    plt_show(show)
    return figs
@verbose
def plot_joint(
    self,
    *,
    timefreqs=None,
    picks=None,
    exclude=(),
    combine="mean",
    tmin=None,
    tmax=None,
    fmin=None,
    fmax=None,
    baseline=None,
    mode="mean",
    dB=False,
    yscale="auto",
    vmin=None,
    vmax=None,
    vlim=(None, None),
    cnorm=None,
    cmap=None,
    colorbar=True,
    title=None,  # TODO consider deprecating this one, or adding an "auto" option
    show=True,
    topomap_args=None,
    image_args=None,
    verbose=None,
):
    """Plot TFRs as a two-dimensional image with topomap highlights.

    Parameters
    ----------
    %(timefreqs)s
    %(picks_good_data)s
    %(exclude_psd)s
        Default is an empty :class:`tuple` which includes all channels.
    %(combine_tfr_plot_joint)s

        .. versionchanged:: 1.3
           Added support for ``callable``.
    %(tmin_tmax_psd)s
    %(fmin_fmax_tfr)s
    %(baseline_rescale)s

        How baseline is computed is determined by the ``mode`` parameter.
    %(mode_tfr_plot)s
    %(dB_tfr_plot_topo)s
    %(yscale_tfr_plot)s
    %(vmin_vmax_tfr_plot)s
    %(vlim_tfr_plot_joint)s
    %(cnorm)s
    %(cmap_tfr_plot_topo)s
    %(colorbar_tfr_plot_joint)s
    %(title_none)s
    %(show)s
    %(topomap_args)s
    %(image_args)s
    %(verbose)s

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the topography. If the picks span several
        channel types, a list with one figure per type is returned instead.

    Notes
    -----
    %(notes_timefreqs_tfr_plot_joint)s

    .. versionadded:: 0.16.0
    """
    from matplotlib import ticker
    from matplotlib.patches import ConnectionPatch

    # deprecations
    vlim = _warn_deprecated_vmin_vmax(vlim, vmin, vmax)
    # handle recursion: one figure per channel type
    picks = _picks_to_idx(
        self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False
    )
    all_ch_types = np.array(self.get_channel_types())
    uniq_ch_types = sorted(set(all_ch_types[picks]))
    if len(uniq_ch_types) > 1:
        msg = "Multiple channel types selected, returning one figure per type."
        logger.info(msg)
        figs = list()
        for this_type in uniq_ch_types:
            this_picks = np.intersect1d(
                picks,
                np.nonzero(np.isin(all_ch_types, this_type))[0],
                assume_unique=True,
            )
            # TODO might be nice to not "copy first, then pick"; alternative might
            # be to subset the data with `this_picks` and then construct the "copy"
            # using __getstate__ and __setstate__
            _tfr = self.copy().pick(this_picks)
            figs.append(
                _tfr.plot_joint(
                    timefreqs=timefreqs,
                    picks=None,
                    baseline=baseline,
                    mode=mode,
                    tmin=tmin,
                    tmax=tmax,
                    fmin=fmin,
                    fmax=fmax,
                    vlim=vlim,
                    cnorm=cnorm,  # previously dropped in the recursive call
                    cmap=cmap,
                    dB=dB,
                    colorbar=colorbar,
                    show=False,
                    title=title,
                    yscale=yscale,
                    combine=combine,
                    exclude=(),
                    topomap_args=topomap_args,
                    image_args=image_args,  # previously dropped as well
                    verbose=verbose,
                )
            )
        return figs
    else:
        ch_type = uniq_ch_types[0]

    # handle defaults
    _validate_type(combine, ("str", "callable"), item_name="combine")  # no `None`
    image_args = dict() if image_args is None else image_args
    # NB: the local name `topomap_args` is read back by `self.plot()` through
    # frame introspection (for the interactive rectangle selector); keep it.
    topomap_args = dict() if topomap_args is None else topomap_args.copy()
    # make sure if topomap_args["ch_type"] is set, it matches what is in `self.info`
    topomap_args.setdefault("ch_type", ch_type)
    if topomap_args["ch_type"] != ch_type:
        raise ValueError(
            f"topomap_args['ch_type'] is {topomap_args['ch_type']} which does not "
            f"match the channel type present in the object ({ch_type})."
        )
    # some necessary defaults
    topomap_args.setdefault("outlines", "head")
    topomap_args.setdefault("contours", 6)
    # don't pass these:
    topomap_args.pop("axes", None)
    topomap_args.pop("show", None)
    topomap_args.pop("colorbar", None)

    # get the time/freq limits of the image plot, to make sure requested annotation
    # times/freqs are in range
    _, times, freqs = self.get_data(
        picks=picks,
        exclude=(),
        tmin=tmin,
        tmax=tmax,
        fmin=fmin,
        fmax=fmax,
        return_times=True,
        return_freqs=True,
    )
    # validate requested annotation times and freqs
    timefreqs = _get_timefreqs(self, timefreqs)
    valid_timefreqs = dict()
    while timefreqs:
        (_time, _freq), (t_win, f_win) = timefreqs.popitem()
        # convert to half-windows
        t_win /= 2
        f_win /= 2
        # make sure the times / freqs are in-bounds
        msg = (
            "Requested {} exceeds the range of the data ({}). Choose different "
            "`timefreqs`."
        )
        if (times > _time).all() or (times < _time).all():
            _var = f"time point ({_time:0.3f} s)"
            _range = f"{times[0]:0.3f} - {times[-1]:0.3f} s"
            raise ValueError(msg.format(_var, _range))
        elif (freqs > _freq).all() or (freqs < _freq).all():
            _var = f"frequency ({_freq:0.1f} Hz)"
            _range = f"{freqs[0]:0.1f} - {freqs[-1]:0.1f} Hz"
            raise ValueError(msg.format(_var, _range))
        # snap the times/freqs to the nearest point we have an estimate for, and
        # store the validated points
        if t_win == 0:
            _time = times[np.argmin(np.abs(times - _time))]
        if f_win == 0:
            _freq = freqs[np.argmin(np.abs(freqs - _freq))]
        valid_timefreqs[(_time, _freq)] = (t_win, f_win)

    # prep data for topomaps (unlike image plot, must include all channels of the
    # current ch_type). Don't pass tmin/tmax here (crop later after baselining)
    topomap_picks = _picks_to_idx(self.info, ch_type)
    data, times, freqs = self.get_data(
        picks=topomap_picks, exclude=(), return_times=True, return_freqs=True
    )
    # merge grads before baselining (makes ERDS visible)
    info = pick_info(self.info, sel=topomap_picks, copy=True)
    data, pos = _merge_if_grads(
        data=data,
        info=info,
        ch_type=ch_type,
        sphere=topomap_args.get("sphere"),
        combine=combine,
    )
    # loop over intended topomap locations, to find one vlim that works for all.
    tf_array = np.array(list(valid_timefreqs))  # each row is [time, freq]
    tf_array = tf_array[tf_array[:, 0].argsort()]  # sort by time
    _vmin, _vmax = (np.inf, -np.inf)
    topomap_arrays = list()
    topomap_titles = list()
    for _time, _freq in tf_array:
        # reduce data to the range of interest in the TF plane (i.e., finally crop)
        t_win, f_win = valid_timefreqs[(_time, _freq)]
        _tmin, _tmax = np.array([-1, 1]) * t_win + _time
        _fmin, _fmax = np.array([-1, 1]) * f_win + _freq
        _data, *_ = _prep_data_for_plot(
            data,
            times,
            freqs,
            tmin=_tmin,
            tmax=_tmax,
            fmin=_fmin,
            fmax=_fmax,
            baseline=baseline,
            mode=mode,
            verbose=verbose,
        )
        _data = _data.mean(axis=(-1, -2))  # avg over times and freqs
        topomap_arrays.append(_data)
        _vmin = min(_data.min(), _vmin)
        _vmax = max(_data.max(), _vmax)
        # construct topopmap subplot title
        t_pm = "" if t_win == 0 else f" ± {t_win:0.2f}"
        f_pm = "" if f_win == 0 else f" ± {f_win:0.1f}"
        _title = f"{_time:0.2f}{t_pm} s,\n{_freq:0.1f}{f_pm} Hz"
        topomap_titles.append(_title)
    # handle cmap. Power may be negative depending on baseline strategy so set
    # `norm` empirically. vmin/vmax will be handled separately within the `plot()`
    # call for the image plot.
    norm = np.min(topomap_arrays) >= 0.0
    cmap = _setup_cmap(cmap, norm=norm)
    topomap_args.setdefault("cmap", cmap[0])  # prevent interactive cbar
    # finalize topomap vlims and compute contour locations.
    # By passing `data=None` here ↓↓↓↓ we effectively assert vmin & vmax aren't None
    _vlim = _setup_vmin_vmax(data=None, vmin=_vmin, vmax=_vmax, norm=norm)
    topomap_args.setdefault("vlim", _vlim)
    locator, topomap_args["contours"] = _set_contour_locator(
        *topomap_args["vlim"], topomap_args["contours"]
    )
    # initialize figure and do the image plot. `self.plot()` needed to wait to be
    # called until after `topomap_args` was fully populated --- we don't pass the
    # dict through to `self.plot()` explicitly here, but we do "reach back" and get
    # it if it's needed by the interactive rectangle selector.
    fig, image_ax, topomap_axes = _prepare_joint_axes(len(valid_timefreqs))
    fig = self.plot(
        picks=picks,
        exclude=(),
        tmin=tmin,
        tmax=tmax,
        fmin=fmin,
        fmax=fmax,
        baseline=baseline,
        mode=mode,
        dB=dB,
        combine=combine,
        yscale=yscale,
        vlim=vlim,
        cnorm=cnorm,
        cmap=cmap,
        colorbar=False,
        title=title,
        # mask, mask_style, mask_cmap, mask_alpha
        axes=image_ax,
        show=False,
        verbose=verbose,
        **image_args,
    )[0]  # [0] because `.plot()` always returns a list
    # now, actually plot the topomaps (distinct loop name so the `title`
    # parameter is not shadowed)
    for ax, topo_title, _data in zip(topomap_axes, topomap_titles, topomap_arrays):
        ax.set_title(topo_title)
        plot_topomap(_data, pos, axes=ax, show=False, **topomap_args)
    # draw colorbar (from the last topomap; all share the same vlim)
    if colorbar:
        cbar = fig.colorbar(ax.images[0])
        cbar.locator = ticker.MaxNLocator(nbins=5) if locator is None else locator
        cbar.update_ticks()
    # draw the connection lines between time-frequency image and topoplots
    for (time_, freq_), topo_ax in zip(tf_array, topomap_axes):
        con = ConnectionPatch(
            xyA=[time_, freq_],
            xyB=[0.5, 0],
            coordsA="data",
            coordsB="axes fraction",
            axesA=image_ax,
            axesB=topo_ax,
            color="grey",
            linestyle="-",
            linewidth=1.5,
            alpha=0.66,
            zorder=1,
            clip_on=False,
        )
        fig.add_artist(con)
    plt_show(show)
    return fig
@verbose
def plot_topo(
    self,
    picks=None,
    baseline=None,
    mode="mean",
    tmin=None,
    tmax=None,
    fmin=None,
    fmax=None,
    vmin=None,  # TODO deprecate in favor of `vlim` (needs helper func refactor)
    vmax=None,
    layout=None,
    cmap="RdBu_r",
    title=None,  # don't deprecate; topo titles aren't standard (color, size, just.)
    dB=False,
    colorbar=True,
    layout_scale=0.945,
    show=True,
    border="none",
    fig_facecolor="k",
    fig_background=None,
    font_color="w",
    yscale="auto",
    verbose=None,
):
    """Plot one TFR image per channel, arranged according to a sensor layout.

    Parameters
    ----------
    %(picks_good_data)s
    %(baseline_rescale)s

        How baseline is computed is determined by the ``mode`` parameter.
    %(mode_tfr_plot)s
    %(tmin_tmax_psd)s
    %(fmin_fmax_tfr)s
    %(vmin_vmax_tfr_plot_topo)s
    %(layout_spectrum_plot_topo)s
    %(cmap_tfr_plot_topo)s
    %(title_none)s
    %(dB_tfr_plot_topo)s
    %(colorbar)s
    %(layout_scale)s
    %(show)s
    %(border_topo)s
    %(fig_facecolor)s
    %(fig_background)s
    %(font_color)s
    %(yscale_tfr_plot)s
    %(verbose)s

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the topography.
    """
    all_times = self.times.copy()
    # Restrict the measurement info and the data array to the requested channels.
    picked_info, picked_data = _prepare_picks(
        self.info, self.data, picks, axis=0
    )
    del picks
    # Crop, baseline-correct and optionally dB-scale the data; this also resolves
    # the final vmin/vmax. TODO: this is the last remaining caller of
    # _preproc_tfr; it should be refactored (to use _prep_data_for_plot?).
    (tfr_data, tfr_times, tfr_freqs, vmin, vmax) = _preproc_tfr(
        picked_data,
        all_times,
        self.freqs,
        tmin,
        tmax,
        fmin,
        fmax,
        mode,
        baseline,
        vmin,
        vmax,
        dB,
        picked_info["sfreq"],
    )
    if layout is None:
        from mne import find_layout

        layout = find_layout(self.info)
    # Both the per-sensor image and the enlarged "click" view share the same
    # data, frequencies and rectangle-selection callback.
    on_select = partial(self._onselect, baseline=baseline, mode=mode)
    shared_kwargs = dict(tfr=tfr_data, freq=tfr_freqs, onselect=on_select)
    on_click = partial(_imshow_tfr, yscale=yscale, cmap=(cmap, True), **shared_kwargs)
    draw_each = partial(_imshow_tfr_unified, cmap=cmap, **shared_kwargs)
    fig = _plot_topo(
        info=picked_info,
        times=tfr_times,
        show_func=draw_each,
        click_func=on_click,
        layout=layout,
        colorbar=colorbar,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        layout_scale=layout_scale,
        title=title,
        border=border,
        x_label="Time (s)",
        y_label="Frequency (Hz)",
        fig_facecolor=fig_facecolor,
        font_color=font_color,
        unified=True,
        img=True,
    )
    add_background_image(fig, fig_background)
    plt_show(show)
    return fig
@copy_function_doc_to_method_doc(plot_tfr_topomap)
def plot_topomap(
    self,
    tmin=None,
    tmax=None,
    fmin=0.0,
    fmax=np.inf,
    *,
    ch_type=None,
    baseline=None,
    mode="mean",
    sensors=True,
    show_names=False,
    mask=None,
    mask_params=None,
    contours=6,
    outlines="head",
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    res=64,
    size=2,
    cmap=None,
    vlim=(None, None),
    cnorm=None,
    colorbar=True,
    cbar_fmt="%1.1e",
    units=None,
    axes=None,
    show=True,
):
    # Thin method wrapper around the module-level ``plot_tfr_topomap``; the
    # decorator copies that function's docstring onto this method, so no
    # docstring is written here (it would be appended to the copied one).
    # Every argument is forwarded unchanged.
    topomap_kwargs = dict(
        ch_type=ch_type,
        baseline=baseline,
        mode=mode,
        sensors=sensors,
        show_names=show_names,
        mask=mask,
        mask_params=mask_params,
        contours=contours,
        outlines=outlines,
        sphere=sphere,
        image_interp=image_interp,
        extrapolate=extrapolate,
        border=border,
        res=res,
        size=size,
        cmap=cmap,
        vlim=vlim,
        cnorm=cnorm,
        colorbar=colorbar,
        cbar_fmt=cbar_fmt,
        units=units,
        axes=axes,
        show=show,
    )
    return plot_tfr_topomap(
        self, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, **topomap_kwargs
    )
@verbose
def save(self, fname, *, overwrite=False, verbose=None):
    """Write the time-frequency object to an HDF5 file.

    Parameters
    ----------
    fname : path-like
        Destination path; it should end with ``-tfr.h5`` or ``-tfr.hdf5``.
    %(overwrite)s
    %(verbose)s

    See Also
    --------
    mne.time_frequency.read_tfrs
    """
    write_hdf5 = _import_h5io_funcs()[1]
    # Validate the filename suffix first, then resolve the path / overwrite check.
    check_fname(fname, "time-frequency object", (".h5", ".hdf5"))
    fname = _check_fname(fname, overwrite=overwrite, verbose=verbose)
    state = self.__getstate__()
    # Any metadata entry is passed through the metadata-writing helper before
    # being handed to h5io.
    if "metadata" in state:
        state["metadata"] = _prepare_write_metadata(state["metadata"])
    write_hdf5(
        fname,
        state,
        overwrite=overwrite,
        title="mnepython",
        slash="replace",
    )
@verbose
def to_data_frame(
    self,
    picks=None,
    index=None,
    long_format=False,
    time_format=None,
    *,
    verbose=None,
):
    """Export data in tabular structure as a pandas DataFrame.

    Channels are converted to columns in the DataFrame. By default,
    additional columns ``'time'``, ``'freq'``, ``'epoch'``, and
    ``'condition'`` (epoch event description) are added, unless ``index``
    is not ``None`` (in which case the columns specified in ``index`` will
    be used to form the DataFrame's index instead). ``'epoch'``, and
    ``'condition'`` are not supported for ``AverageTFR``.

    Parameters
    ----------
    %(picks_all)s
    %(index_df_epo)s
        Valid string values are ``'time'``, ``'freq'``, ``'epoch'``, and
        ``'condition'`` for ``EpochsTFR`` and ``'time'`` and ``'freq'``
        for ``AverageTFR``.
        Defaults to ``None``.
    %(long_format_df_epo)s
    %(time_format_df)s

        .. versionadded:: 0.23
    %(verbose)s

    Returns
    -------
    %(df_return)s
    """
    # check pandas once here, instead of in each private utils function
    _check_pandas_installed()
    is_epochs = isinstance(self, EpochsTFR)
    # arg checking
    valid_index_args = ["time", "freq"]
    if is_epochs:
        valid_index_args.extend(["epoch", "condition"])
    valid_time_formats = ["ms", "timedelta"]
    index = _check_pandas_index_arguments(index, valid_index_args)
    time_format = _check_time_format(time_format, valid_time_formats)
    # get data
    picks = _picks_to_idx(self.info, picks, "all", exclude=())
    data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True)
    axis = self._dims.index("channel")
    if not is_epochs:
        data = data[np.newaxis]  # add singleton "epochs" axis
        axis += 1
    n_epochs, n_picks, n_freqs, n_times = data.shape
    # move channels last, then flatten (epochs, freqs, times) into rows
    data = np.moveaxis(data, axis, -1)
    data = data.reshape(n_epochs * n_freqs * n_times, n_picks)
    # prepare extra columns / multiindex; row order is epoch -> freq -> time
    times = _convert_times(times, time_format, self.info["meas_date"])
    mindex = [
        ("time", np.tile(times, n_epochs * n_freqs)),
        ("freq", np.tile(np.repeat(freqs, n_times), n_epochs)),
    ]
    if is_epochs:
        mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs)))
        rev_event_id = {v: k for k, v in self.event_id.items()}
        conditions = [rev_event_id[k] for k in self.events[:, 2]]
        mindex.append(("condition", np.repeat(conditions, n_times * n_freqs)))
    # Sanity check: every index column must have one entry per data row. (The
    # previous check compared the lengths of the ``(name, column)`` tuples,
    # which are always 2, so it could never fail.)
    n_rows = data.shape[0]
    assert all(len(column) == n_rows for _, column in mindex)
    # build DataFrame
    if is_epochs:
        default_index = ["condition", "epoch", "freq", "time"]
    else:
        default_index = ["freq", "time"]
    df = _build_data_frame(
        self, data, picks, long_format, mindex, index, default_index=default_index
    )
    return df
@fill_doc
class AverageTFR(BaseTFR):
    """Data object for spectrotemporal representations of averaged data.

    .. warning:: The preferred means of creating AverageTFR objects is via the
                 instance methods :meth:`mne.Epochs.compute_tfr` and
                 :meth:`mne.Evoked.compute_tfr`, or via
                 :meth:`mne.time_frequency.EpochsTFR.average`. Direct class
                 instantiation is discouraged.

    Parameters
    ----------
    %(info_not_none)s

        .. deprecated:: 1.7
            Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead, or
            use :class:`~mne.time_frequency.AverageTFRArray` which retains the old API.
    data : ndarray, shape (n_channels, n_freqs, n_times)
        The data.

        .. deprecated:: 1.7
            Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead, or
            use :class:`~mne.time_frequency.AverageTFRArray` which retains the old API.
    times : ndarray, shape (n_times,)
        The time values in seconds.

        .. deprecated:: 1.7
            Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead and
            (optionally) use ``tmin`` and ``tmax`` to restrict the time domain; or use
            :class:`~mne.time_frequency.AverageTFRArray` which retains the old API.
    freqs : ndarray, shape (n_freqs,)
        The frequencies in Hz.
    nave : int
        The number of averaged TFRs.

        .. deprecated:: 1.7
            Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead;
            ``nave`` will be inferred automatically. Or, use
            :class:`~mne.time_frequency.AverageTFRArray` which retains the old API.
    inst : instance of Evoked | instance of Epochs | dict
        The data from which to compute the time-frequency representation. Passing a
        :class:`dict` will create the AverageTFR using the ``__setstate__`` interface
        and is not recommended for typical use cases.
    %(method_tfr)s
    %(freqs_tfr)s
    %(tmin_tmax_psd)s
    %(picks_good_data_noref)s
    %(proj_psd)s
    %(decim_tfr)s
    %(comment_averagetfr)s
    %(n_jobs)s
    %(verbose)s
    %(method_kw_tfr)s

    Attributes
    ----------
    %(baseline_tfr_attr)s
    %(ch_names_tfr_attr)s
    %(comment_averagetfr_attr)s
    %(freqs_tfr_attr)s
    %(info_not_none)s
    %(method_tfr_attr)s
    %(nave_tfr_attr)s
    %(sfreq_tfr_attr)s
    %(shape_tfr_attr)s

    See Also
    --------
    RawTFR
    EpochsTFR
    AverageTFRArray
    mne.Evoked.compute_tfr
    mne.time_frequency.EpochsTFR.average

    Notes
    -----
    The old API (prior to version 1.7) was::

        AverageTFR(info, data, times, freqs, nave, comment=None, method=None)

    That API is still available via :class:`~mne.time_frequency.AverageTFRArray` for
    cases where the data are precomputed or do not originate from MNE-Python objects.
    The preferred new API uses instance methods::

        evoked.compute_tfr(method, freqs, ...)
        epochs.compute_tfr(method, freqs, average=True, ...)

    The new API also supports AverageTFR instantiation from a :class:`dict`, but this
    is primarily for save/load and internal purposes, and wraps ``__setstate__``.

    During the transition from the old to the new API, it may be expedient to use
    :class:`~mne.time_frequency.AverageTFRArray` as a "quick-fix" approach to updating
    scripts under active development.

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        info=None,
        data=None,
        times=None,
        freqs=None,
        nave=None,
        *,
        inst=None,
        method=None,
        tmin=None,
        tmax=None,
        picks=None,
        proj=False,
        decim=1,
        comment=None,
        n_jobs=None,
        verbose=None,
        **method_kw,
    ):
        from ..epochs import BaseEpochs
        from ..evoked import Evoked
        from ._stockwell import _check_input_st, _compute_freqs_st

        # Deprecations (mirrors EpochsTFR.__init__). TODO remove after 1.7 release.
        # These legacy params used to be silently ignored, which surfaced only as a
        # confusing type error about ``inst`` being None.
        depr_params = dict(info=info, data=data, times=times, nave=nave)
        bad_params = [name for name, param in depr_params.items() if param is not None]
        if len(bad_params):
            _s = _pl(bad_params)
            is_are = _pl(bad_params, "is", "are")
            bad_params_list = '", "'.join(bad_params)
            warn(
                f'Parameter{_s} "{bad_params_list}" {is_are} deprecated and will be '
                "removed in version 1.8. For a quick fix, use ``AverageTFRArray`` with "
                "the same parameters. For a long-term fix, see the docstring notes.",
                FutureWarning,
            )
            if inst is not None:
                raise ValueError(
                    "Do not pass `inst` alongside deprecated params "
                    f'"{bad_params_list}"; see docstring of AverageTFR for guidance.'
                )
            # only pass through user-specified values; __setstate__ supplies the
            # defaults (e.g. ``nave=1``, ``comment=""``) for anything omitted
            legacy_state = depr_params | dict(
                freqs=freqs, method=method, comment=comment
            )
            inst = {key: val for key, val in legacy_state.items() if val is not None}
        # end TODO ↑↑↑↑↑↑
        # dict is allowed for __setstate__ compatibility, and Epochs.compute_tfr() can
        # return an AverageTFR depending on its parameters, so Epochs input is allowed
        _validate_type(
            inst, (BaseEpochs, Evoked, dict), "object passed to AverageTFR constructor"
        )
        # stockwell API is very different from multitaper/morlet
        if method == "stockwell" and not isinstance(inst, dict):
            if isinstance(freqs, str) and freqs == "auto":
                fmin, fmax = None, None
            elif freqs is not None and len(freqs) == 2:
                fmin, fmax = freqs
            else:
                raise ValueError(
                    "for Stockwell method, freqs must be a length-2 iterable "
                    f'or "auto", got {freqs}.'
                )
            method_kw.update(fmin=fmin, fmax=fmax)
            # Compute freqs. We need a couple lines of code dupe here (also in
            # BaseTFR.__init__) to get the subset of times to pass to _check_input_st()
            _mask = _time_mask(inst.times, tmin, tmax, sfreq=inst.info["sfreq"])
            _times = inst.times[_mask].copy()
            _, default_nfft, _ = _check_input_st(_times, None)
            n_fft = method_kw.get("n_fft", default_nfft)
            *_, freqs = _compute_freqs_st(fmin, fmax, n_fft, inst.info["sfreq"])
        # use Evoked.comment or str(Epochs.event_id) as the default comment...
        if comment is None:
            comment = getattr(inst, "comment", ",".join(getattr(inst, "event_id", "")))
        # ...but don't overwrite if it's coming in with a comment already set
        if isinstance(inst, dict):
            inst.setdefault("comment", comment)
        else:
            self._comment = getattr(self, "_comment", comment)
        super().__init__(
            inst,
            method,
            freqs,
            tmin=tmin,
            tmax=tmax,
            picks=picks,
            proj=proj,
            decim=decim,
            n_jobs=n_jobs,
            verbose=verbose,
            **method_kw,
        )

    def __getstate__(self):
        """Prepare AverageTFR object for serialization."""
        out = super().__getstate__()
        out.update(nave=self.nave, comment=self.comment)
        # NOTE: self._itc should never exist in the instance returned to the user; it
        # is temporarily present in the output from the tfr_array_* function, and is
        # split out into a separate AverageTFR object (and deleted from the object
        # holding power estimates) before those objects are passed back to the user.
        # The following lines are there because we make use of __getstate__ to achieve
        # that splitting of objects.
        if hasattr(self, "_itc"):
            out.update(itc=self._itc)
        return out

    def __setstate__(self, state):
        """Unpack AverageTFR from serialized format."""
        super().__setstate__(state)
        # missing keys fall back to an empty comment and a single average
        self._comment = state.get("comment", "")
        self._nave = state.get("nave", 1)

    @property
    def comment(self):
        return self._comment

    @comment.setter
    def comment(self, comment):
        self._comment = comment

    @property
    def nave(self):
        return self._nave

    @nave.setter
    def nave(self, nave):
        self._nave = nave

    def _get_instance_data(self, time_mask):
        # AverageTFRs can be constructed from Epochs data, so we triage shape here.
        # Evoked data get a fake singleton "epoch" axis prepended
        dim = slice(None) if _get_instance_type_string(self) == "Epochs" else np.newaxis
        data = self.inst.get_data(picks=self._picks)[dim, :, time_mask]
        # Evoked carries its own ``nave``; otherwise use the number of epochs
        self._nave = getattr(self.inst, "nave", data.shape[0])
        return data
@fill_doc
class AverageTFRArray(AverageTFR):
    """Data object for *precomputed* spectrotemporal representations of averaged data.

    Parameters
    ----------
    %(info_not_none)s
    %(data_tfr)s
    %(times)s
    %(freqs_tfr_array)s
    nave : int
        The number of averaged TFRs.
    %(comment_averagetfr_attr)s
    %(method_tfr_array)s

    Attributes
    ----------
    %(baseline_tfr_attr)s
    %(ch_names_tfr_attr)s
    %(comment_averagetfr_attr)s
    %(freqs_tfr_attr)s
    %(info_not_none)s
    %(method_tfr_attr)s
    %(nave_tfr_attr)s
    %(sfreq_tfr_attr)s
    %(shape_tfr_attr)s

    See Also
    --------
    AverageTFR
    EpochsTFRArray
    mne.Epochs.compute_tfr
    mne.Evoked.compute_tfr
    """

    def __init__(
        self, info, data, times, freqs, *, nave=None, comment=None, method=None
    ):
        # The four required fields always populate the state; the optional ones
        # are added only when given, so __setstate__ can apply its own defaults.
        extras = dict(nave=nave, comment=comment, method=method)
        state = dict(info=info, data=data, times=times, freqs=freqs)
        state.update({key: val for key, val in extras.items() if val is not None})
        self.__setstate__(state)
@fill_doc
class EpochsTFR(BaseTFR, GetEpochsMixin):
    """Data object for spectrotemporal representations of epoched data.

    .. important::
       The preferred means of creating EpochsTFR objects from :class:`~mne.Epochs`
       objects is via the instance method :meth:`~mne.Epochs.compute_tfr`.
       To create an EpochsTFR object from pre-computed data (i.e., a NumPy array) use
       :class:`~mne.time_frequency.EpochsTFRArray`.

    Parameters
    ----------
    %(info_not_none)s

        .. deprecated:: 1.7
            Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use
            :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API.
    data : ndarray, shape (n_channels, n_freqs, n_times)
        The data.

        .. deprecated:: 1.7
            Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use
            :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API.
    times : ndarray, shape (n_times,)
        The time values in seconds.

        .. deprecated:: 1.7
            Pass an instance of :class:`~mne.Epochs` as ``inst`` instead and
            (optionally) use ``tmin`` and ``tmax`` to restrict the time domain; or use
            :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API.
    %(freqs_tfr_epochs)s
    inst : instance of Epochs
        The data from which to compute the time-frequency representation.
    %(method_tfr_epochs)s
    %(comment_tfr_attr)s

        .. deprecated:: 1.7
            Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use
            :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API.
    %(tmin_tmax_psd)s
    %(picks_good_data_noref)s
    %(proj_psd)s
    %(decim_tfr)s
    %(events_epochstfr)s

        .. deprecated:: 1.7
            Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use
            :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API.
    %(event_id_epochstfr)s

        .. deprecated:: 1.7
            Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use
            :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API.
    selection : array
        List of indices of selected events (not dropped or ignored etc.). For
        example, if the original event array had 4 events and the second event
        has been dropped, this attribute would be np.array([0, 2, 3]).

        .. deprecated:: 1.7
            Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use
            :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API.
    drop_log : tuple of tuple
        A tuple of the same length as the event array used to initialize the
        ``EpochsTFR`` object. If the i-th original event is still part of the
        selection, drop_log[i] will be an empty tuple; otherwise it will be
        a tuple of the reasons the event is no longer in the selection, e.g.:

        - ``'IGNORED'``
            If it isn't part of the current subset defined by the user
        - ``'NO_DATA'`` or ``'TOO_SHORT'``
            If epoch didn't contain enough data
        - names of channels that exceeded the amplitude threshold
        - ``'EQUALIZED_COUNTS'``
            See :meth:`~mne.Epochs.equalize_event_counts`
        - ``'USER'``
            For user-defined reasons (see :meth:`~mne.Epochs.drop`).

        .. deprecated:: 1.7
            Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use
            :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API.
    %(metadata_epochstfr)s

        .. deprecated:: 1.7
            Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use
            :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API.
    %(n_jobs)s
    %(verbose)s
    %(method_kw_tfr)s

    Attributes
    ----------
    %(baseline_tfr_attr)s
    %(ch_names_tfr_attr)s
    %(comment_tfr_attr)s
    %(drop_log)s
    %(event_id_attr)s
    %(events_attr)s
    %(freqs_tfr_attr)s
    %(info_not_none)s
    %(metadata_attr)s
    %(method_tfr_attr)s
    %(selection_attr)s
    %(sfreq_tfr_attr)s
    %(shape_tfr_attr)s

    See Also
    --------
    mne.Epochs.compute_tfr
    RawTFR
    AverageTFR
    EpochsTFRArray

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        info=None,
        data=None,
        times=None,
        freqs=None,
        *,
        inst=None,
        method=None,
        comment=None,
        tmin=None,
        tmax=None,
        picks=None,
        proj=False,
        decim=1,
        events=None,
        event_id=None,
        selection=None,
        drop_log=None,
        metadata=None,
        n_jobs=None,
        verbose=None,
        **method_kw,
    ):
        from ..epochs import BaseEpochs

        # deprecations. TODO remove after 1.7 release
        depr_params = dict(info=info, data=data, times=times, comment=comment)
        bad_params = list()
        for name, param in depr_params.items():
            if param is not None:
                bad_params.append(name)
        if len(bad_params):
            _s = _pl(bad_params)
            is_are = _pl(bad_params, "is", "are")
            bad_params_list = '", "'.join(bad_params)
            warn(
                f'Parameter{_s} "{bad_params_list}" {is_are} deprecated and will be '
                "removed in version 1.8. For a quick fix, use ``EpochsTFRArray`` with "
                "the same parameters. For a long-term fix, see the docstring notes.",
                FutureWarning,
            )
            if inst is not None:
                raise ValueError(
                    "Do not pass `inst` alongside deprecated params "
                    f'"{bad_params_list}"; see docstring of EpochsTFR for guidance.'
                )
            # sensible defaults are created in __setstate__ so only pass these through
            # if they're user-specified
            optional = dict(
                freqs=freqs,
                method=method,
                events=events,
                event_id=event_id,
                selection=selection,
                drop_log=drop_log,
                metadata=metadata,
            )
            optional_params = {
                key: val for key, val in optional.items() if val is not None
            }
            inst = depr_params | optional_params
        # end TODO ↑↑↑↑↑↑
        # dict is allowed for __setstate__ compatibility
        _validate_type(
            inst, (BaseEpochs, dict), "object passed to EpochsTFR constructor", "Epochs"
        )
        super().__init__(
            inst,
            method,
            freqs,
            tmin=tmin,
            tmax=tmax,
            picks=picks,
            proj=proj,
            decim=decim,
            n_jobs=n_jobs,
            verbose=verbose,
            **method_kw,
        )

    @fill_doc
    def __getitem__(self, item):
        """Subselect epochs from an EpochsTFR.

        Parameters
        ----------
        %(item)s
            Access options are the same as for :class:`~mne.Epochs` objects, see the
            docstring Notes section of :meth:`mne.Epochs.__getitem__` for explanation.

        Returns
        -------
        %(getitem_epochstfr_return)s
        """
        return super().__getitem__(item)

    def __getstate__(self):
        """Prepare EpochsTFR object for serialization."""
        out = super().__getstate__()
        out.update(
            metadata=self._metadata,
            drop_log=self.drop_log,
            event_id=self.event_id,
            events=self.events,
            selection=self.selection,
            raw_times=self._raw_times,
        )
        return out

    def __setstate__(self, state):
        """Unpack EpochsTFR from serialized format."""
        if state["data"].ndim != 4:
            raise ValueError(f"EpochsTFR data should be 4D, got {state['data'].ndim}.")
        super().__setstate__(state)
        self._metadata = state.get("metadata", None)
        n_epochs = self.shape[0]
        n_times = self.shape[-1]
        # Placeholder events (one per epoch, spaced n_times samples apart, all with
        # event code 1), used only when the state doesn't provide real events.
        fake_samps = np.linspace(
            n_times, n_times * (n_epochs + 1), n_epochs, dtype=int, endpoint=False
        )
        fake_events = np.dstack(
            (fake_samps, np.zeros_like(fake_samps), np.ones_like(fake_samps))
        ).squeeze(axis=0)
        self.events = state.get("events", _ensure_events(fake_events))
        self.event_id = state.get("event_id", _check_event_id(None, self.events))
        self.drop_log = state.get("drop_log", tuple())
        self.selection = state.get("selection", np.arange(n_epochs))
        self._bad_dropped = True  # always true, need for `equalize_event_counts()`

    def __next__(self, return_event_id=False):
        """Iterate over EpochsTFR objects.

        NOTE: __iter__() and _stop_iter() are defined by the GetEpochs mixin.

        Parameters
        ----------
        return_event_id : bool
            If ``True``, return both the EpochsTFR data and its associated ``event_id``.

        Returns
        -------
        epoch : array of shape (n_channels, n_freqs, n_times)
            The single-epoch time-frequency data.
        event_id : int
            The integer event id associated with the epoch. Only returned if
            ``return_event_id`` is ``True``.
        """
        if self._current >= len(self._data):
            self._stop_iter()
        epoch = self._data[self._current]
        event_id = self.events[self._current][-1]
        self._current += 1
        if return_event_id:
            return epoch, event_id
        return epoch

    def _check_singleton(self):
        """Check if self contains only one Epoch, and return it as an AverageTFR."""
        if self.shape[0] > 1:
            # name the public method that was called, for a more helpful error
            calling_func = inspect.currentframe().f_back.f_code.co_name
            raise NotImplementedError(
                f"Cannot call {calling_func}() from EpochsTFR with multiple epochs; "
                "please subselect a single epoch before plotting."
            )
        return list(self.iter_evoked())[0]

    def _get_instance_data(self, time_mask):
        # Epochs.get_data() is indexed as (epochs, channels, times)
        return self.inst.get_data(picks=self._picks)[:, :, time_mask]

    def _update_epoch_attributes(self):
        # adjust dims and shape
        if self.method != "stockwell":  # stockwell consumes epochs dimension
            self._dims = ("epoch",) + self._dims
            self._shape = (len(self.inst),) + self._shape
        # we need these for to_data_frame()
        self.event_id = self.inst.event_id.copy()
        self.events = self.inst.events.copy()
        self.selection = self.inst.selection.copy()
        # we need these for __getitem__()
        self.drop_log = deepcopy(self.inst.drop_log)
        self._metadata = self.inst.metadata
        # we need this for compatibility with equalize_event_counts()
        self._bad_dropped = True

    def average(self, method="mean", *, dim="epochs", copy=False):
        """Aggregate the EpochsTFR across epochs, frequencies, or times.

        Parameters
        ----------
        method : "mean" | "median" | callable
            How to aggregate the data across the given ``dim``. If callable,
            must take a :class:`NumPy array<numpy.ndarray>` of shape
            ``(n_epochs, n_channels, n_freqs, n_times)`` and return an array
            with one fewer dimensions (which dimension is collapsed depends on
            the value of ``dim``). Default is ``"mean"``.
        dim : "epochs" | "freqs" | "times"
            The dimension along which to combine the data.
        copy : bool
            Whether to return a copy of the modified instance, or modify in place.
            Ignored when ``dim="epochs"`` or ``"times"`` because those options return
            different types (:class:`~mne.time_frequency.AverageTFR` and
            :class:`~mne.time_frequency.EpochsSpectrum`, respectively).

        Returns
        -------
        tfr : instance of EpochsTFR | AverageTFR | EpochsSpectrum
            The aggregated TFR object.

        Notes
        -----
        Passing in ``np.median`` is considered unsafe for complex data; pass
        the string ``"median"`` instead to compute the *marginal* median
        (i.e. the median of the real and imaginary components separately).
        See discussion here:

        https://github.com/scipy/scipy/pull/12676#issuecomment-783370228
        """
        _check_option("dim", dim, ("epochs", "freqs", "times"))
        axis = self._dims.index(dim[:-1])  # self._dims entries aren't plural
        func = _check_combine(mode=method, axis=axis)
        data = func(self.data)
        n_epochs, n_channels, n_freqs, n_times = self.data.shape
        freqs, times = self.freqs, self.times
        if dim == "epochs":
            expected_shape = self._data.shape[1:]
        elif dim == "freqs":
            expected_shape = (n_epochs, n_channels, n_times)
            freqs = np.mean(self.freqs, keepdims=True)
        elif dim == "times":
            expected_shape = (n_epochs, n_channels, n_freqs)
            times = np.mean(self.times, keepdims=True)
        if data.shape != expected_shape:
            raise RuntimeError(
                "EpochsTFR.average() got a method that resulted in data of shape "
                f"{data.shape}, but it should be {expected_shape}."
            )
        state = self.__getstate__()
        # restore singleton freqs axis (not necessary for epochs/times: class changes)
        if dim == "freqs":
            data = np.expand_dims(data, axis=axis)
        else:
            state["dims"] = (*state["dims"][:axis], *state["dims"][axis + 1 :])
        state["data"] = data
        state["info"] = deepcopy(self.info)
        state["freqs"] = freqs
        state["times"] = times
        if dim == "epochs":
            state["inst_type_str"] = "Evoked"
            state["nave"] = n_epochs
            state["comment"] = f"{method} of {n_epochs} EpochsTFR{_pl(n_epochs)}"
            out = AverageTFR(inst=state)
            out._data_type = "Average Power"
            return out
        elif dim == "times":
            return EpochsSpectrum(
                state,
                method=None,
                fmin=None,
                fmax=None,
                tmin=None,
                tmax=None,
                picks=None,
                exclude=None,
                proj=None,
                remove_dc=None,
                n_jobs=None,
            )
        # ↓↓↓ these two are for dim == "freqs"
        elif copy:
            return EpochsTFR(inst=state, method=None, freqs=None)
        else:
            # ``data`` already had its singleton freqs axis restored above;
            # expanding it again here would leave the instance holding 5-D data.
            self._data = data
            self._freqs = freqs
            return self

    @verbose
    def drop(self, indices, reason="USER", verbose=None):
        """Drop epochs based on indices or boolean mask.

        .. note:: The indices refer to the current set of undropped epochs
                  rather than the complete set of dropped and undropped epochs.
                  They are therefore not necessarily consistent with any
                  external indices (e.g., behavioral logs). To drop epochs
                  based on external criteria, do not use the ``preload=True``
                  flag when constructing an Epochs object, and call this
                  method before calling the :meth:`mne.Epochs.drop_bad` or
                  :meth:`mne.Epochs.load_data` methods.

        Parameters
        ----------
        indices : array of int or bool
            Set epochs to remove by specifying indices to remove or a boolean
            mask to apply (where True values get removed). Events are
            correspondingly modified.
        reason : str
            Reason for dropping the epochs ('ECG', 'timeout', 'blink' etc).
            Default: 'USER'.
        %(verbose)s

        Returns
        -------
        epochs : instance of Epochs or EpochsTFR
            The epochs with indices dropped. Operates in-place.
        """
        from ..epochs import BaseEpochs

        # reuse the Epochs implementation directly on this object
        BaseEpochs.drop(self, indices=indices, reason=reason, verbose=verbose)
        return self

    def iter_evoked(self, copy=False):
        """Iterate over EpochsTFR to yield a sequence of AverageTFR objects.

        The AverageTFR objects will each contain a single epoch (i.e., no averaging is
        performed). This method resets the EpochTFR instance's iteration state to the
        first epoch.

        Parameters
        ----------
        copy : bool
            Whether to yield copies of the data and measurement info, or views/pointers.
        """
        self.__iter__()
        base_state = self.__getstate__()
        base_state["inst_type_str"] = "Evoked"
        base_state["dims"] = base_state["dims"][1:]  # drop "epochs"
        while True:
            try:
                data, event_id = self.__next__(return_event_id=True)
            except StopIteration:
                break
            # Build a fresh state dict per epoch. AverageTFR only *defaults* the
            # comment of a dict input (via ``setdefault``), so reusing one dict
            # would give every yielded object the first epoch's comment.
            state = dict(base_state)
            if copy:
                state["info"] = deepcopy(self.info)
                state["data"] = data.copy()
            else:
                state["data"] = data
            state["nave"] = 1
            state["comment"] = str(event_id)
            yield AverageTFR(inst=state, method=None, freqs=None, comment=str(event_id))

    @verbose
    @copy_doc(BaseTFR.plot)
    def plot(
        self,
        picks=None,
        *,
        exclude=(),
        tmin=None,
        tmax=None,
        fmin=None,
        fmax=None,
        baseline=None,
        mode="mean",
        dB=False,
        combine=None,
        layout=None,  # TODO deprecate; not used in orig implementation
        yscale="auto",
        vmin=None,
        vmax=None,
        vlim=(None, None),
        cnorm=None,
        cmap=None,
        colorbar=True,
        title=None,  # don't deprecate this one; has (useful) option title="auto"
        mask=None,
        mask_style=None,
        mask_cmap="Greys",
        mask_alpha=0.1,
        axes=None,
        show=True,
        verbose=None,
    ):
        # Only single-epoch objects can be plotted; delegate to AverageTFR.plot.
        # (Docstring is copied from BaseTFR.plot by the decorator.)
        singleton_epoch = self._check_singleton()
        return singleton_epoch.plot(
            picks=picks,
            exclude=exclude,
            tmin=tmin,
            tmax=tmax,
            fmin=fmin,
            fmax=fmax,
            baseline=baseline,
            mode=mode,
            dB=dB,
            combine=combine,
            layout=layout,
            yscale=yscale,
            vmin=vmin,
            vmax=vmax,
            vlim=vlim,
            cnorm=cnorm,
            cmap=cmap,
            colorbar=colorbar,
            title=title,
            mask=mask,
            mask_style=mask_style,
            mask_cmap=mask_cmap,
            mask_alpha=mask_alpha,
            axes=axes,
            show=show,
            verbose=verbose,
        )

    @verbose
    @copy_doc(BaseTFR.plot_topo)
    def plot_topo(
        self,
        picks=None,
        baseline=None,
        mode="mean",
        tmin=None,
        tmax=None,
        fmin=None,
        fmax=None,
        vmin=None,  # TODO deprecate in favor of `vlim` (needs helper func refactor)
        vmax=None,
        layout=None,
        cmap=None,
        title=None,  # don't deprecate; topo titles aren't standard (color, size, just.)
        dB=False,
        colorbar=True,
        layout_scale=0.945,
        show=True,
        border="none",
        fig_facecolor="k",
        fig_background=None,
        font_color="w",
        yscale="auto",
        verbose=None,
    ):
        # Only single-epoch objects can be plotted; delegate to AverageTFR.
        singleton_epoch = self._check_singleton()
        return singleton_epoch.plot_topo(
            picks=picks,
            baseline=baseline,
            mode=mode,
            tmin=tmin,
            tmax=tmax,
            fmin=fmin,
            fmax=fmax,
            vmin=vmin,
            vmax=vmax,
            layout=layout,
            cmap=cmap,
            title=title,
            dB=dB,
            colorbar=colorbar,
            layout_scale=layout_scale,
            show=show,
            border=border,
            fig_facecolor=fig_facecolor,
            fig_background=fig_background,
            font_color=font_color,
            yscale=yscale,
            verbose=verbose,
        )

    @verbose
    @copy_doc(BaseTFR.plot_joint)
    def plot_joint(
        self,
        *,
        timefreqs=None,
        picks=None,
        exclude=(),
        combine="mean",
        tmin=None,
        tmax=None,
        fmin=None,
        fmax=None,
        baseline=None,
        mode="mean",
        dB=False,
        yscale="auto",
        vmin=None,
        vmax=None,
        vlim=(None, None),
        cnorm=None,
        cmap=None,
        colorbar=True,
        title=None,
        show=True,
        topomap_args=None,
        image_args=None,
        verbose=None,
    ):
        # Only single-epoch objects can be plotted; delegate to AverageTFR.
        singleton_epoch = self._check_singleton()
        return singleton_epoch.plot_joint(
            timefreqs=timefreqs,
            picks=picks,
            exclude=exclude,
            combine=combine,
            tmin=tmin,
            tmax=tmax,
            fmin=fmin,
            fmax=fmax,
            baseline=baseline,
            mode=mode,
            dB=dB,
            yscale=yscale,
            vmin=vmin,
            vmax=vmax,
            vlim=vlim,
            cnorm=cnorm,
            cmap=cmap,
            colorbar=colorbar,
            title=title,
            show=show,
            topomap_args=topomap_args,
            image_args=image_args,
            verbose=verbose,
        )

    @copy_doc(BaseTFR.plot_topomap)
    def plot_topomap(
        self,
        tmin=None,
        tmax=None,
        fmin=0.0,
        fmax=np.inf,
        *,
        ch_type=None,
        baseline=None,
        mode="mean",
        sensors=True,
        show_names=False,
        mask=None,
        mask_params=None,
        contours=6,
        outlines="head",
        sphere=None,
        image_interp=_INTERPOLATION_DEFAULT,
        extrapolate=_EXTRAPOLATE_DEFAULT,
        border=_BORDER_DEFAULT,
        res=64,
        size=2,
        cmap=None,
        vlim=(None, None),
        cnorm=None,
        colorbar=True,
        cbar_fmt="%1.1e",
        units=None,
        axes=None,
        show=True,
    ):
        # Only single-epoch objects can be plotted; delegate to AverageTFR.
        singleton_epoch = self._check_singleton()
        return singleton_epoch.plot_topomap(
            tmin=tmin,
            tmax=tmax,
            fmin=fmin,
            fmax=fmax,
            ch_type=ch_type,
            baseline=baseline,
            mode=mode,
            sensors=sensors,
            show_names=show_names,
            mask=mask,
            mask_params=mask_params,
            contours=contours,
            outlines=outlines,
            sphere=sphere,
            image_interp=image_interp,
            extrapolate=extrapolate,
            border=border,
            res=res,
            size=size,
            cmap=cmap,
            vlim=vlim,
            cnorm=cnorm,
            colorbar=colorbar,
            cbar_fmt=cbar_fmt,
            units=units,
            axes=axes,
            show=show,
        )
@fill_doc
class EpochsTFRArray(EpochsTFR):
    """Data object for *precomputed* spectrotemporal representations of epoched data.

    Parameters
    ----------
    %(info_not_none)s
    %(data_tfr)s
    %(times)s
    %(freqs_tfr_array)s
    %(comment_tfr_attr)s
    %(method_tfr_array)s
    %(events_epochstfr)s
    %(event_id_epochstfr)s
    %(selection)s
    %(drop_log)s
    %(metadata_epochstfr)s

    Attributes
    ----------
    %(baseline_tfr_attr)s
    %(ch_names_tfr_attr)s
    %(comment_tfr_attr)s
    %(drop_log)s
    %(event_id_attr)s
    %(events_attr)s
    %(freqs_tfr_attr)s
    %(info_not_none)s
    %(metadata_attr)s
    %(method_tfr_attr)s
    %(selection_attr)s
    %(sfreq_tfr_attr)s
    %(shape_tfr_attr)s

    See Also
    --------
    AverageTFR
    mne.Epochs.compute_tfr
    mne.Evoked.compute_tfr
    """

    def __init__(
        self,
        info,
        data,
        times,
        freqs,
        *,
        comment=None,
        method=None,
        events=None,
        event_id=None,
        selection=None,
        drop_log=None,
        metadata=None,
    ):
        # Required fields always go into the state; optional fields are included
        # only when supplied, so that __setstate__ can fill in its defaults.
        supplied = dict(
            comment=comment,
            method=method,
            events=events,
            event_id=event_id,
            selection=selection,
            drop_log=drop_log,
            metadata=metadata,
        )
        state = dict(info=info, data=data, times=times, freqs=freqs)
        state.update({key: val for key, val in supplied.items() if val is not None})
        self.__setstate__(state)
@fill_doc
class RawTFR(BaseTFR):
    """Data object for spectrotemporal representations of continuous data.
    .. warning:: The preferred means of creating RawTFR objects from
                 :class:`~mne.io.Raw` objects is via the instance method
                 :meth:`~mne.io.Raw.compute_tfr`. Direct class instantiation
                 is not supported.
    Parameters
    ----------
    inst : instance of Raw
        The data from which to compute the time-frequency representation.
    %(method_tfr)s
    %(freqs_tfr)s
    %(tmin_tmax_psd)s
    %(picks_good_data_noref)s
    %(proj_psd)s
    %(reject_by_annotation_tfr)s
    %(decim_tfr)s
    %(n_jobs)s
    %(verbose)s
    %(method_kw_tfr)s
    Attributes
    ----------
    ch_names : list
        The channel names.
    freqs : array
        Frequencies at which the amplitude, power, or fourier coefficients
        have been computed.
    %(info_not_none)s
    method : str
        The method used to compute the spectra (``'morlet'``, ``'multitaper'``
        or ``'stockwell'``).
    See Also
    --------
    mne.io.Raw.compute_tfr
    EpochsTFR
    AverageTFR
    References
    ----------
    .. footbibliography::
    """
    def __init__(
        self,
        inst,
        method=None,
        freqs=None,
        *,
        tmin=None,
        tmax=None,
        picks=None,
        proj=False,
        reject_by_annotation=False,
        decim=1,
        n_jobs=None,
        verbose=None,
        **method_kw,
    ):
        # Function-level import; presumably avoids a circular import between
        # the io and time_frequency packages. NOTE(review): confirm.
        from ..io import BaseRaw
        # dict is allowed for __setstate__ compatibility
        _validate_type(
            inst, (BaseRaw, dict), "object passed to RawTFR constructor", "Raw"
        )
        # Everything beyond the type check above is handled by BaseTFR.__init__.
        super().__init__(
            inst,
            method,
            freqs,
            tmin=tmin,
            tmax=tmax,
            picks=picks,
            proj=proj,
            reject_by_annotation=reject_by_annotation,
            decim=decim,
            n_jobs=n_jobs,
            verbose=verbose,
            **method_kw,
        )
    def __getitem__(self, item):
        """Get RawTFR data.
        Parameters
        ----------
        item : int | slice | array-like
            Indexing is similar to a :class:`NumPy array<numpy.ndarray>`; see
            Notes.
        Returns
        -------
        %(getitem_tfr_return)s
        Notes
        -----
        The last axis is always time, the next-to-last axis is always
        frequency, and the first axis is always channel. If
        ``method='multitaper'`` and ``output='complex'`` then the second axis
        will be taper index.
        Integer-, list-, and slice-based indexing is possible:
        - ``raw_tfr[[0, 2]]`` gives the whole time-frequency plane for the
          first and third channels.
        - ``raw_tfr[..., :3, :]`` gives the first 3 frequency bins and all
          times for all channels (and tapers, if present).
        - ``raw_tfr[..., :100]`` gives the first 100 time samples in all
          frequency bins for all channels (and tapers).
        - ``raw_tfr[(4, 7)]`` is the same as ``raw_tfr[4, 7]``.
        .. note::
           Unlike :class:`~mne.io.Raw` objects (which returns a tuple of the
           requested data values and the corresponding times), accessing
           :class:`~mne.time_frequency.RawTFR` values via subscript does
           **not** return the corresponding frequency bin values. If you need
           them, use ``RawTFR.freqs[freq_indices]`` or
           ``RawTFR.get_data(..., return_freqs=True)``.
        """
        from ..io import BaseRaw
        # Reuse BaseRaw's indexing machinery on this object. BaseRaw._getitem
        # presumably calls ``self._parse_get_set_params``, so BaseRaw's
        # implementation is bound to this instance first. Note that this
        # assigns an instance attribute as a side effect of indexing, and it
        # must happen before the _getitem call below.
        self._parse_get_set_params = partial(BaseRaw._parse_get_set_params, self)
        # return_times=False: only data values are returned (see Notes).
        return BaseRaw._getitem(self, item, return_times=False)
    def _get_instance_data(self, time_mask, reject_by_annotation):
        # First and last True positions of ``time_mask``; ``stop + 1`` below
        # makes the last sample inclusive. An all-False mask raises IndexError.
        start, stop = np.where(time_mask)[0][[0, -1]]
        # "NaN" is forwarded as ``reject_by_annotation`` to Raw.get_data;
        # presumably annotated-bad spans then come back as NaN (confirm).
        rba = "NaN" if reject_by_annotation else None
        data = self.inst.get_data(
            self._picks, start, stop + 1, reject_by_annotation=rba
        )
        # prepend a singleton "epochs" axis
        return data[np.newaxis]
@fill_doc
class RawTFRArray(RawTFR):
    """Data object for *precomputed* spectrotemporal representations of continuous data.
    Parameters
    ----------
    %(info_not_none)s
    %(data_tfr)s
    %(times)s
    %(freqs_tfr_array)s
    %(method_tfr_array)s
    Attributes
    ----------
    %(baseline_tfr_attr)s
    %(ch_names_tfr_attr)s
    %(freqs_tfr_attr)s
    %(info_not_none)s
    %(method_tfr_attr)s
    %(sfreq_tfr_attr)s
    %(shape_tfr_attr)s
    See Also
    --------
    RawTFR
    mne.io.Raw.compute_tfr
    EpochsTFRArray
    AverageTFRArray
    """
    def __init__(
        self,
        info,
        data,
        times,
        freqs,
        *,
        method=None,
    ):
        # ``method`` only becomes part of the state when explicitly given;
        # everything else is handled by ``__setstate__``.
        extra = {} if method is None else {"method": method}
        self.__setstate__(
            dict(info=info, data=data, times=times, freqs=freqs, **extra)
        )
def combine_tfr(all_tfr, weights="nave"):
    """Merge AverageTFR data by weighted addition.

    Create a new AverageTFR instance, using a combination of the supplied
    instances as its data. By default, the mean (weighted by trials) is used.
    Subtraction can be performed by passing negative weights (e.g., [1, -1]).
    Data must have the same channels and the same time instants.

    Parameters
    ----------
    all_tfr : list of AverageTFR
        The tfr datasets.
    weights : list of float | str
        The weights to apply to the data of each AverageTFR instance.
        Can also be ``'nave'`` to weight according to tfr.nave,
        or ``'equal'`` to use equal weighting (each weighted as ``1/N``).

    Returns
    -------
    tfr : AverageTFR
        The new TFR data.

    Raises
    ------
    ValueError
        If ``all_tfr`` is empty, if ``weights`` is invalid or its length does
        not match ``all_tfr``, or if the instances do not share the same
        channels and time instants.

    Notes
    -----
    .. versionadded:: 0.11.0
    """
    if len(all_tfr) == 0:
        raise ValueError("all_tfr must contain at least one AverageTFR instance")
    # work on a copy so the first input is never modified
    tfr = all_tfr[0].copy()
    if isinstance(weights, str):
        if weights not in ("nave", "equal"):
            raise ValueError('Weights must be a list of float, or "nave" or "equal"')
        if weights == "nave":
            weights = np.array([e.nave for e in all_tfr], float)
            weights /= weights.sum()
        else:  # == 'equal'
            weights = [1.0 / len(all_tfr)] * len(all_tfr)
    weights = np.array(weights, float)
    if weights.ndim != 1 or weights.size != len(all_tfr):
        raise ValueError("Weights must be the same size as all_tfr")
    ch_names = tfr.ch_names
    for t_ in all_tfr[1:]:
        # Explicit raises (not ``assert``) so validation survives ``python -O``
        # and callers get the documented ValueError.
        if t_.ch_names != ch_names:
            raise ValueError(f"{tfr} and {t_} do not contain the same channels")
        # check the shape first so mismatched lengths do not surface as an
        # opaque numpy broadcasting error; ``not ... <`` keeps NaN as a mismatch
        if t_.times.shape != tfr.times.shape or not (
            np.max(np.abs(t_.times - tfr.times)) < 1e-7
        ):
            raise ValueError(
                f"{tfr} and {t_} do not contain the same time instants"
            )
    # use union of bad channels
    bads = list(set(tfr.info["bads"]).union(*(t_.info["bads"] for t_ in all_tfr[1:])))
    tfr.info["bads"] = bads
    # XXX : should be refactored with combined_evoked function
    tfr.data = sum(w * t_.data for w, t_ in zip(weights, all_tfr))
    # effective number of averages for a weighted sum of independent means
    tfr.nave = max(int(1.0 / sum(w**2 / e.nave for w, e in zip(weights, all_tfr))), 1)
    return tfr
# Utils
# ↓↓↓↓↓↓↓↓↓↓↓ this is still used in _stockwell.py
def _get_data(inst, return_itc):
    """Return the data of an Epochs or Evoked object as epochs x ch x time.

    Epochs data are returned via ``get_data(copy=False)``. Evoked data get a
    leading singleton axis and are copied.

    Raises
    ------
    TypeError
        If ``inst`` is neither Epochs nor Evoked.
    ValueError
        If ``return_itc`` is requested for Evoked data.
    """
    from ..epochs import BaseEpochs
    from ..evoked import Evoked

    if isinstance(inst, BaseEpochs):
        return inst.get_data(copy=False)
    if not isinstance(inst, Evoked):
        raise TypeError("inst must be Epochs or Evoked")
    # a single average has no trials across which to compute ITC
    if return_itc:
        raise ValueError("return_itc must be False for evoked data")
    return inst.data[np.newaxis].copy()
def _prepare_picks(info, data, picks, axis):
    """Restrict ``info`` and ``data`` to the requested (non-bad) channels.

    Parameters
    ----------
    info : Info
        Measurement info to subset.
    data : ndarray
        Array whose ``axis`` dimension indexes channels.
    picks : str | array-like | slice | None
        Channel selection, resolved with bad channels excluded.
    axis : int
        Channel axis of ``data`` (negative values index from the end).

    Returns
    -------
    info : Info
        The picked info.
    data : ndarray
        ``data`` restricted to the picked channels along ``axis``.
    """
    picks = _picks_to_idx(info, picks, exclude="bads")
    # full slices everywhere except the channel axis
    indexer = [slice(None)] * data.ndim
    indexer[axis] = picks
    return pick_info(info, picks), data[tuple(indexer)]
def _centered(arr, newsize):
"""Aux Function to center data."""
# Return the center newsize portion of the array.
newsize = np.asarray(newsize)
currsize = np.array(arr.shape)
startind = (currsize - newsize) // 2
endind = startind + newsize
myslice = [slice(startind[k], endind[k]) for k in range(len(endind))]
return arr[tuple(myslice)]
def _preproc_tfr(
    data,
    times,
    freqs,
    tmin,
    tmax,
    fmin,
    fmax,
    mode,
    baseline,
    vmin,
    vmax,
    dB,
    sfreq,
    copy=None,
):
    """Baseline-correct, convert to power, crop, and scale TFR data for plotting.

    Parameters
    ----------
    data : ndarray
        Indexed below as ``data[:, freq, time]``.
    times, freqs : ndarray
        Time points and frequencies matching the last two axes of ``data``.
    tmin, tmax, fmin, fmax : float | None
        Crop limits; ``None`` leaves that side uncropped.
    mode, baseline
        Forwarded to ``rescale``.
    vmin, vmax
        Forwarded to ``_setup_vmin_vmax`` together with the processed data.
    dB : bool
        If True, return ``10 * log10`` of the data.
    sfreq : float
        Passed to ``_time_mask`` for both the time and frequency crops.
        NOTE(review): using it for frequencies looks unusual; confirm intent.
    copy : bool | None
        Passed to ``rescale``; ``None`` means copy only when a baseline is set.

    Returns
    -------
    data, times, freqs, vmin, vmax
    """
    # only pay for a copy when rescale will actually modify the array
    copy = (baseline is not None) if copy is None else copy
    data = rescale(data, times, baseline, mode, copy=copy)
    if np.iscomplexobj(data):
        # complex amplitude → real power (for plotting); if data are
        # real-valued they should already be power
        data = (data * data.conj()).real

    def _bounds(values, lo, hi):
        # slice covering the masked range; an unset limit leaves that end open
        keep = np.where(_time_mask(values, lo, hi, sfreq=sfreq))[0]
        start = None if lo is None else keep[0]
        stop = None if hi is None else keep[-1] + 1
        return slice(start, stop)

    t_slice = _bounds(times, tmin, tmax)
    f_slice = _bounds(freqs, fmin, fmax)
    times = times[t_slice]
    freqs = freqs[f_slice]
    data = data[:, f_slice, t_slice]
    if dB:
        data = 10 * np.log10(data)
    vmin, vmax = _setup_vmin_vmax(data, vmin, vmax)
    return data, times, freqs, vmin, vmax
def _ensure_slice(decim):
    """Normalize ``decim`` to a slice whose ``step`` is never ``None``.

    Integers become ``slice(None, None, decim)``. Slices without a step get an
    explicit step of 1. Other slices are returned unchanged.
    """
    _validate_type(decim, ("int-like", slice), "decim")
    if not isinstance(decim, slice):
        return slice(None, None, int(decim))
    if decim.step is None:
        # callers rely on being able to use ``decim.step`` directly
        return slice(decim.start, decim.stop, 1)
    return decim
# i/o
@verbose
def write_tfrs(fname, tfr, overwrite=False, *, verbose=None):
    """Write a TFR dataset to hdf5.
    Parameters
    ----------
    fname : path-like
        The file name, which should end with ``-tfr.h5``.
    tfr : RawTFR | EpochsTFR | AverageTFR | list of RawTFR | list of EpochsTFR | list of AverageTFR
        The (list of) TFR object(s) to save in one file. If ``tfr.comment`` is ``None``,
        a sequential numeric string name will be generated on the fly, based on the
        order in which the TFR objects are passed. This can be used to selectively load
        single TFR objects from the file later.
    %(overwrite)s
    %(verbose)s
    See Also
    --------
    read_tfrs
    Notes
    -----
    .. versionadded:: 0.9.0
    """  # noqa E501
    _, write_hdf5 = _import_h5io_funcs()
    tfrs = tfr if isinstance(tfr, (list, tuple)) else [tfr]
    # list of (comment, state) pairs, in input order
    payload = []
    for position, this_tfr in enumerate(tfrs):
        # objects without a comment are keyed by their position in the input
        label = getattr(this_tfr, "comment", None)
        if label is None:
            label = position
        state = this_tfr.__getstate__()
        if "metadata" in state:
            state["metadata"] = _prepare_write_metadata(state["metadata"])
        payload.append((label, state))
    write_hdf5(fname, payload, overwrite=overwrite, title="mnepython", slash="replace")
@verbose
def read_tfrs(fname, condition=None, *, verbose=None):
    """Load a TFR object from disk.
    Parameters
    ----------
    fname : path-like
        Path to a TFR file in HDF5 format, which should end with ``-tfr.h5`` or
        ``-tfr.hdf5``.
    condition : int or str | list of int or str | None
        The condition to load. If ``None``, all conditions will be returned.
        Defaults to ``None``.
    %(verbose)s
    Returns
    -------
    tfr : RawTFR | EpochsTFR | AverageTFR | list of RawTFR | list of EpochsTFR | list of AverageTFR
        The loaded time-frequency object.
    See Also
    --------
    mne.time_frequency.RawTFR.save
    mne.time_frequency.EpochsTFR.save
    mne.time_frequency.AverageTFR.save
    write_tfrs
    Notes
    -----
    .. versionadded:: 0.9.0
    """  # noqa E501
    read_hdf5, _ = _import_h5io_funcs()
    fname = _check_fname(fname=fname, overwrite="read", must_exist=False)
    # accept "-tfr" / "_tfr" combined with ".h5" / ".hdf5"
    allowed_endings = tuple(
        f"{sep}tfr.{ext}" for sep in ("-", "_") for ext in ("h5", "hdf5")
    )
    check_fname(fname, "tfr", allowed_endings)
    logger.info(f"Reading {fname} ...")
    contents = read_hdf5(fname, title="mnepython", slash="replace")
    if "inst_type_str" not in contents:
        # maybe multiple TFRs from write_tfrs()
        return _read_multiple_tfrs(contents, condition=condition, verbose=verbose)
    # single TFR from TFR.save(): pick the class from the stored state
    if "epoch" in contents["dims"]:
        klass = EpochsTFR
    elif "nave" in contents:
        klass = AverageTFR
    else:
        klass = RawTFR
    loaded = klass(inst=contents)
    if getattr(loaded, "metadata", None) is not None:
        loaded.metadata = _prepare_read_metadata(loaded.metadata)
    return loaded
@verbose
def _read_multiple_tfrs(tfr_data, condition=None, *, verbose=None):
    """Read (possibly multiple) TFR datasets from an h5 file written by write_tfrs().

    Parameters
    ----------
    tfr_data : list of tuple
        ``(comment, state_dict)`` pairs as written by ``write_tfrs``.
    condition : int | str | None
        If given, only the entry whose comment matches is returned. A match is
        accepted either by equality or by string form, so ``"0"`` selects an
        auto-numbered entry ``0``. Only supported for AverageTFR entries.
    verbose : bool | str | int | None
        Control verbosity of the logging output.

    Returns
    -------
    out : RawTFR | EpochsTFR | AverageTFR | list
        A single object if exactly one entry was read, otherwise a list.

    Raises
    ------
    NotImplementedError
        If ``condition`` is given and an entry is not an AverageTFR.
    ValueError
        If no entry matches ``condition``.
    """
    out = list()
    keys = list()
    # tfr_data is a list of (comment, tfr_dict) tuples
    for key, tfr in tfr_data:
        keys.append(str(key))  # auto-assigned keys are ints
        dims = tfr.get("dims")
        if dims is not None:
            # Prefer the stored dimension names: ``data.ndim`` is ambiguous
            # when a taper axis is present (e.g. multitaper RawTFR is 4-D).
            is_epochs = "epoch" in dims
        else:
            # fall back to the shape heuristic for states without ``dims``
            is_epochs = tfr["data"].ndim == 4
        is_average = "nave" in tfr
        if condition is not None:
            if not is_average:
                raise NotImplementedError(
                    "condition is only supported when reading AverageTFRs."
                )
            # keys are reported to the user as strings (see error below), so
            # accept a string form of an auto-assigned integer key too
            if key != condition and str(key) != str(condition):
                continue
        tfr = dict(tfr)
        tfr["info"] = Info(tfr["info"])
        tfr["info"]._check_consistency()
        if "metadata" in tfr:
            tfr["metadata"] = _prepare_read_metadata(tfr["metadata"])
        # additional keys needed for TFR __setstate__
        defaults = dict(baseline=None, data_type="Power Estimates")
        if is_epochs:
            Klass = EpochsTFR
            defaults.update(
                inst_type_str="Epochs", dims=("epoch", "channel", "freq", "time")
            )
        elif is_average:
            Klass = AverageTFR
            defaults.update(inst_type_str="Evoked", dims=("channel", "freq", "time"))
        else:
            Klass = RawTFR
            defaults.update(inst_type_str="Raw", dims=("channel", "freq", "time"))
        # values stored in the file take precedence over the defaults
        out.append(Klass(inst=defaults | tfr))
    if len(out) == 0:
        raise ValueError(
            f'Cannot find condition "{condition}" in this file. '
            f'The file contains conditions {", ".join(keys)}'
        )
    if len(out) == 1:
        out = out[0]
    return out
def _get_timefreqs(tfr, timefreqs):
"""Find and/or setup timefreqs for `tfr.plot_joint`."""
# Input check
timefreq_error_msg = (
"Supplied `timefreqs` are somehow malformed. Please supply None, "
"a list of tuple pairs, or a dict of such tuple pairs, not {}"
)
if isinstance(timefreqs, dict):
for k, v in timefreqs.items():
for item in (k, v):
if len(item) != 2 or any(not _is_numeric(n) for n in item):
raise ValueError(timefreq_error_msg, item)
elif timefreqs is not None:
if not hasattr(timefreqs, "__len__"):
raise ValueError(timefreq_error_msg.format(timefreqs))
if len(timefreqs) == 2 and all(_is_numeric(v) for v in timefreqs):
timefreqs = [tuple(timefreqs)] # stick a pair of numbers in a list
else:
for item in timefreqs:
if (
hasattr(item, "__len__")
and len(item) == 2
and all(_is_numeric(n) for n in item)
):
pass
else:
raise ValueError(timefreq_error_msg.format(item))
# If None, automatic identification of max peak
else:
order = max((1, tfr.data.shape[2] // 30))
peaks_idx = argrelmax(tfr.data, order=order, axis=2)
if peaks_idx[0].size == 0:
_, p_t, p_f = np.unravel_index(tfr.data.argmax(), tfr.data.shape)
timefreqs = [(tfr.times[p_t], tfr.freqs[p_f])]
else:
peaks = [tfr.data[0, f, t] for f, t in zip(peaks_idx[1], peaks_idx[2])]
peakmax_idx = np.argmax(peaks)
peakmax_time = tfr.times[peaks_idx[2][peakmax_idx]]
peakmax_freq = tfr.freqs[peaks_idx[1][peakmax_idx]]
timefreqs = [(peakmax_time, peakmax_freq)]
timefreqs = {
tuple(k): np.asarray(timefreqs[k])
if isinstance(timefreqs, dict)
else np.array([0, 0])
for k in timefreqs
}
return timefreqs
def _check_tfr_complex(tfr, reason="source space estimation"):
"""Check that time-frequency epochs or average data is complex."""
if not np.iscomplexobj(tfr.data):
raise RuntimeError(f"Time-frequency data must be complex for {reason}")
def _merge_if_grads(data, info, ch_type, sphere, combine=None):
    """Return (possibly merged) channel data and sensor positions for a topomap.

    For ``ch_type == "grad"``, the paired gradiometer rows of ``data`` are
    merged with ``combine`` (when it is a string) or ``"rms"``. Positions are
    taken from every other pick. Other channel types are returned unmerged,
    with positions from ``_get_pos_outlines``.
    """
    if ch_type != "grad":
        pos, _ = _get_pos_outlines(info, picks=ch_type, sphere=sphere)
        return data, pos
    grad_picks = _pair_grad_sensors(info, topomap_coords=False)
    # presumably consecutive picks form a pair, so one position per pair
    pos = _find_topomap_coords(info, picks=grad_picks[::2], sphere=sphere)
    # only string values of ``combine`` name a merge method here
    method = combine if isinstance(combine, str) else "rms"
    merged, _ = _merge_ch_data(data[grad_picks], ch_type, [], method=method)
    return merged, pos
@verbose
def _prep_data_for_plot(
    data,
    times,
    freqs,
    *,
    tmin=None,
    tmax=None,
    fmin=None,
    fmax=None,
    baseline=None,
    mode=None,
    dB=False,
    verbose=None,
):
    """Baseline-correct, crop, and convert TFR data to (optionally dB) power.

    ``data`` is cropped on its last two axes, which are indexed as
    ``[..., freq, time]``. Returns ``(data, times, freqs)`` after cropping.
    """
    # rescale only needs its own copy when a baseline will change values
    data = rescale(
        data, times, baseline, mode, copy=baseline is not None, verbose=verbose
    )
    keep_times = np.nonzero(_time_mask(times, tmin, tmax))[0]
    keep_freqs = np.nonzero(_time_mask(freqs, fmin, fmax))[0]
    times = times[keep_times]
    freqs = freqs[keep_freqs]
    # index the two axes separately to avoid fancy-index broadcasting
    data = data[..., keep_freqs, :][..., keep_times]
    # complex amplitude → real power; real-valued data is already power (or ITC)
    if np.iscomplexobj(data):
        data = (data * data.conj()).real
    if dB:
        data = 10 * np.log10(data)
    return data, times, freqs
def _warn_deprecated_vmin_vmax(vlim, vmin, vmax):
if vmin is not None or vmax is not None:
warning = "Parameters `vmin` and `vmax` are deprecated, use `vlim` instead."
if vlim[0] is None and vlim[1] is None:
vlim = (vmin, vmax)
else:
warning += (
" You've also provided a (non-default) value for `vlim`, "
"so `vmin` and `vmax` will be ignored."
)
warn(warning, FutureWarning)
return vlim