1717 lines
54 KiB
Python
1717 lines
54 KiB
Python
"""Container classes for spectral data."""
|
||
|
||
# Authors: The MNE-Python contributors.
|
||
# License: BSD-3-Clause
|
||
# Copyright the MNE-Python contributors.
|
||
|
||
from copy import deepcopy
|
||
from functools import partial
|
||
from inspect import signature
|
||
|
||
import numpy as np
|
||
|
||
from .._fiff.meas_info import ContainsMixin, Info
|
||
from .._fiff.pick import _pick_data_channels, _picks_to_idx, pick_info
|
||
from ..channels.channels import UpdateChannelsMixin
|
||
from ..channels.layout import _merge_ch_data, find_layout
|
||
from ..defaults import (
|
||
_BORDER_DEFAULT,
|
||
_EXTRAPOLATE_DEFAULT,
|
||
_INTERPOLATION_DEFAULT,
|
||
_handle_default,
|
||
)
|
||
from ..html_templates import _get_html_template
|
||
from ..utils import (
|
||
GetEpochsMixin,
|
||
_build_data_frame,
|
||
_check_method_kwargs,
|
||
_check_pandas_index_arguments,
|
||
_check_pandas_installed,
|
||
_check_sphere,
|
||
_time_mask,
|
||
_validate_type,
|
||
fill_doc,
|
||
legacy,
|
||
logger,
|
||
object_diff,
|
||
repr_html,
|
||
verbose,
|
||
warn,
|
||
)
|
||
from ..utils.check import (
|
||
_check_fname,
|
||
_check_option,
|
||
_import_h5io_funcs,
|
||
_is_numeric,
|
||
check_fname,
|
||
)
|
||
from ..utils.misc import _pl
|
||
from ..utils.spectrum import _get_instance_type_string, _split_psd_kwargs
|
||
from ..viz.topo import _plot_timeseries, _plot_timeseries_unified, _plot_topo
|
||
from ..viz.topomap import _make_head_outlines, _prepare_topomap_plot, plot_psds_topomap
|
||
from ..viz.utils import (
|
||
_format_units_psd,
|
||
_get_plot_ch_type,
|
||
_make_combine_callable,
|
||
_plot_psd,
|
||
_prepare_sensor_names,
|
||
plt_show,
|
||
)
|
||
from .multitaper import _psd_from_mt, psd_array_multitaper
|
||
from .psd import _check_nfft, psd_array_welch
|
||
|
||
|
||
class SpectrumMixin:
    """Mixin providing spectral plotting methods to sensor-space containers."""

    # NOTE(review): the legacy wrappers below call ``_split_psd_kwargs`` without
    # passing their arguments explicitly, so that helper presumably inspects the
    # calling frame to collect them. If so, parameter names and defaults here
    # are part of the behavior — confirm before renaming anything.

    @legacy(alt=".compute_psd().plot()")
    @verbose
    def plot_psd(
        self,
        fmin=0,
        fmax=np.inf,
        tmin=None,
        tmax=None,
        picks=None,
        proj=False,
        reject_by_annotation=True,
        *,
        method="auto",
        average=False,
        dB=True,
        estimate="power",
        xscale="linear",
        area_mode="std",
        area_alpha=0.33,
        color="black",
        line_alpha=None,
        spatial_colors=True,
        sphere=None,
        exclude="bads",
        ax=None,
        show=True,
        n_jobs=1,
        verbose=None,
        **method_kw,
    ):
        """%(plot_psd_doc)s.

        Parameters
        ----------
        %(fmin_fmax_psd)s
        %(tmin_tmax_psd)s
        %(picks_good_data_noref)s
        %(proj_psd)s
        %(reject_by_annotation_psd)s
        %(method_plot_psd_auto)s
        %(average_plot_psd)s
        %(dB_plot_psd)s
        %(estimate_plot_psd)s
        %(xscale_plot_psd)s
        %(area_mode_plot_psd)s
        %(area_alpha_plot_psd)s
        %(color_plot_psd)s
        %(line_alpha_plot_psd)s
        %(spatial_colors_psd)s
        %(sphere_topomap_auto)s

            .. versionadded:: 0.22.0
        exclude : list of str | 'bads'
            Channel names to exclude from being shown. If 'bads', the bad
            channels are excluded. Pass an empty list to plot all channels
            (including channels marked "bad", if any).

            .. versionadded:: 0.24.0
        %(ax_plot_psd)s
        %(show)s
        %(n_jobs)s
        %(verbose)s
        %(method_kw_psd)s

        Returns
        -------
        fig : instance of Figure
            Figure with frequency spectra of the data channels.

        Notes
        -----
        %(notes_plot_psd_meth)s
        """
        # Split this call's arguments into the kwargs for ``compute_psd()``
        # (init_kw) and those accepted by ``Spectrum.plot`` (plot_kw).
        init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot)
        # Legacy path: compute a Spectrum, then delegate plotting to it.
        return self.compute_psd(**init_kw).plot(**plot_kw)

    @legacy(alt=".compute_psd().plot_topo()")
    @verbose
    def plot_psd_topo(
        self,
        tmin=None,
        tmax=None,
        fmin=0,
        fmax=100,
        proj=False,
        *,
        method="auto",
        dB=True,
        layout=None,
        color="w",
        fig_facecolor="k",
        axis_facecolor="k",
        axes=None,
        block=False,
        show=True,
        n_jobs=None,
        verbose=None,
        **method_kw,
    ):
        """Plot power spectral density, separately for each channel.

        Parameters
        ----------
        %(tmin_tmax_psd)s
        %(fmin_fmax_psd_topo)s
        %(proj_psd)s
        %(method_plot_psd_auto)s
        %(dB_spectrum_plot_topo)s
        %(layout_spectrum_plot_topo)s
        %(color_spectrum_plot_topo)s
        %(fig_facecolor)s
        %(axis_facecolor)s
        %(axes_spectrum_plot_topo)s
        %(block)s
        %(show)s
        %(n_jobs)s
        %(verbose)s
        %(method_kw_psd)s Defaults to ``dict(n_fft=2048)``.

        Returns
        -------
        fig : instance of matplotlib.figure.Figure
            Figure distributing one image per channel across sensor topography.
        """
        # Route arguments either to ``compute_psd()`` or to
        # ``Spectrum.plot_topo``.
        init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot_topo)
        return self.compute_psd(**init_kw).plot_topo(**plot_kw)

    @legacy(alt=".compute_psd().plot_topomap()")
    @verbose
    def plot_psd_topomap(
        self,
        bands=None,
        tmin=None,
        tmax=None,
        ch_type=None,
        *,
        proj=False,
        method="auto",
        normalize=False,
        agg_fun=None,
        dB=False,
        sensors=True,
        show_names=False,
        mask=None,
        mask_params=None,
        contours=0,
        outlines="head",
        sphere=None,
        image_interp=_INTERPOLATION_DEFAULT,
        extrapolate=_EXTRAPOLATE_DEFAULT,
        border=_BORDER_DEFAULT,
        res=64,
        size=1,
        cmap=None,
        vlim=(None, None),
        cnorm=None,
        colorbar=True,
        cbar_fmt="auto",
        units=None,
        axes=None,
        show=True,
        n_jobs=None,
        verbose=None,
        **method_kw,
    ):
        """Plot scalp topography of PSD for chosen frequency bands.

        Parameters
        ----------
        %(bands_psd_topo)s
        %(tmin_tmax_psd)s
        %(ch_type_topomap_psd)s
        %(proj_psd)s
        %(method_plot_psd_auto)s
        %(normalize_psd_topo)s
        %(agg_fun_psd_topo)s
        %(dB_plot_topomap)s
        %(sensors_topomap)s
        %(show_names_topomap)s
        %(mask_evoked_topomap)s
        %(mask_params_topomap)s
        %(contours_topomap)s
        %(outlines_topomap)s
        %(sphere_topomap_auto)s
        %(image_interp_topomap)s
        %(extrapolate_topomap)s
        %(border_topomap)s
        %(res_topomap)s
        %(size_topomap)s
        %(cmap_topomap)s
        %(vlim_plot_topomap_psd)s
        %(cnorm)s

            .. versionadded:: 1.2
        %(colorbar_topomap)s
        %(cbar_fmt_topomap_psd)s
        %(units_topomap)s
        %(axes_spectrum_plot_topomap)s
        %(show)s
        %(n_jobs)s
        %(verbose)s
        %(method_kw_psd)s

        Returns
        -------
        fig : instance of Figure
            Figure showing one scalp topography per frequency band.
        """
        # Route arguments either to ``compute_psd()`` or to
        # ``Spectrum.plot_topomap``.
        init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot_topomap)
        return self.compute_psd(**init_kw).plot_topomap(**plot_kw)

    def _set_legacy_nfft_default(self, tmin, tmax, method, method_kw):
        """Update method_kw with legacy n_fft default for plot_psd[_topo]().

        This method returns ``None`` and has a side effect of (maybe) updating
        the ``method_kw`` dict.

        Parameters
        ----------
        tmin, tmax : float | None
            Bounds of the time window, forwarded to ``_time_mask`` together
            with ``self.times`` and ``self.info["sfreq"]``.
        method : str
            Name of the PSD method. Only ``"welch"`` triggers an update.
        method_kw : dict
            Keyword arguments for the PSD method. It is modified in place: if
            ``n_fft`` is absent or ``None``, it is set to the number of samples
            inside the time window, capped at 2048.
        """
        if method == "welch" and method_kw.get("n_fft") is None:
            # Boolean mask of samples inside [tmin, tmax]; its sum is the
            # window length in samples. n_fft is capped at that length
            # (presumably so Welch never asks for more samples than exist).
            tm = _time_mask(self.times, tmin, tmax, sfreq=self.info["sfreq"])
            method_kw["n_fft"] = min(np.sum(tm), 2048)
|
||
|
||
|
||
class BaseSpectrum(ContainsMixin, UpdateChannelsMixin):
    """Base class for Spectrum and EpochsSpectrum."""

    def __init__(
        self,
        inst,
        method,
        fmin,
        fmax,
        tmin,
        tmax,
        picks,
        exclude,
        proj,
        remove_dc,
        *,
        n_jobs,
        verbose=None,
        **method_kw,
    ):
        """Validate inputs and set up the state shared by spectrum containers.

        No spectra are computed here. This stores the PSD function (with
        ``remove_dc`` and ``method_kw`` already bound), the time mask and the
        channel picks. ``_compute_spectra`` later consumes and deletes those
        temporary attributes.
        """
        # arg checking
        self._sfreq = inst.info["sfreq"]
        if np.isfinite(fmax) and (fmax > self.sfreq / 2):
            raise ValueError(
                f"Requested fmax ({fmax} Hz) must not exceed ½ the sampling "
                f'frequency of the data ({0.5 * inst.info["sfreq"]} Hz).'
            )
        # method (``_inst_type`` must be set before asking for the type string)
        self._inst_type = type(inst)
        method = _validate_method(method, _get_instance_type_string(self))
        # triage method and kwargs. partial() doesn't check validity of kwargs,
        # so we do it manually to save compute time if any are invalid.
        psd_funcs = dict(welch=psd_array_welch, multitaper=psd_array_multitaper)
        _check_method_kwargs(psd_funcs[method], method_kw, msg=f'PSD method "{method}"')
        self._psd_func = partial(psd_funcs[method], remove_dc=remove_dc, **method_kw)

        # apply proj if desired (on a copy, so the caller's instance is untouched)
        if proj:
            inst = inst.copy().apply_proj()
        self.inst = inst

        # prep times and picks
        self._time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq)
        self._picks = _picks_to_idx(
            inst.info, picks, "data", exclude, with_ref_meg=False
        )

        # add the info object. bads and non-data channels were dropped by
        # _picks_to_idx() so we update the info accordingly:
        self.info = pick_info(inst.info, sel=self._picks, copy=True)

        # assign some attributes
        self.preload = True  # needed for __getitem__, never False
        self._method = method
        # self._dims may also get updated by child classes
        self._dims = (
            "channel",
            "freq",
        )
        # "" is the .get() default because None is itself a valid value
        if method_kw.get("average", "") in (None, False):
            self._dims += ("segment",)
        # per-taper complex output: insert "taper" just before the last axis
        if self._returns_complex_tapers(**method_kw):
            self._dims = self._dims[:-1] + ("taper",) + self._dims[-1:]
        # record data type (for repr and html_repr)
        self._data_type = (
            "Fourier Coefficients"
            if method_kw.get("output") == "complex"
            else "Power Spectrum"
        )
        # set nave (child constructor overrides this for Evoked input)
        self._nave = None

    def __eq__(self, other):
        """Test equivalence of two Spectrum instances."""
        return object_diff(vars(self), vars(other)) == ""

    def __getstate__(self):
        """Prepare object for serialization.

        Returns a plain dict. ``save`` writes it to HDF5 and ``__setstate__``
        restores an object from it.
        """
        inst_type_str = _get_instance_type_string(self)
        out = dict(
            method=self.method,
            data=self._data,
            sfreq=self.sfreq,
            dims=self._dims,
            freqs=self.freqs,
            inst_type_str=inst_type_str,
            data_type=self._data_type,
            info=self.info,
            nave=self.nave,
            weights=self.weights,
        )
        return out

    def __setstate__(self, state):
        """Unpack from serialized format (the dict from ``__getstate__``)."""
        from ..epochs import Epochs
        from ..evoked import Evoked
        from ..io import Raw

        self._method = state["method"]
        self._data = state["data"]
        self._freqs = state["freqs"]
        self._dims = state["dims"]
        self._sfreq = state["sfreq"]
        self.info = Info(**state["info"])
        self._data_type = state["data_type"]
        self._nave = state.get("nave")  # objs saved before #11282 won't have `nave`
        self._weights = state.get("weights")  # objs saved before #12747 won't have
        self.preload = True
        # instance type
        inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Array=np.ndarray)
        self._inst_type = inst_types[state["inst_type_str"]]

    def __repr__(self):
        """Build string representation of the Spectrum object."""
        inst_type_str = _get_instance_type_string(self)
        # shape & dimension names
        dims = " × ".join(
            [f"{dim[0]} {dim[1]}s" for dim in zip(self.shape, self._dims)]
        )
        freq_range = f"{self.freqs[0]:0.1f}-{self.freqs[-1]:0.1f} Hz"
        return (
            f"<{self._data_type} (from {inst_type_str}, "
            f"{self.method} method) | {dims}, {freq_range}>"
        )

    @repr_html
    def _repr_html_(self, caption=None):
        """Build HTML representation of the Spectrum object."""
        inst_type_str = _get_instance_type_string(self)
        units = [f"{ch_type}: {unit}" for ch_type, unit in self.units().items()]
        t = _get_html_template("repr", "spectrum.html.jinja")
        t = t.render(spectrum=self, inst_type=inst_type_str, units=units)
        return t

    def _check_values(self):
        """Check PSD results for correct shape and bad values.

        Asserts that ``_data`` matches the expected ``_shape`` and warns about
        any non-bad channel whose minimum (absolute) value is exactly zero.
        """
        assert len(self._dims) == self._data.ndim, (self._dims, self._data.ndim)
        assert self._data.shape == self._shape
        # TODO: should this be more fine-grained (report "chan X in epoch Y")?
        ch_dim = self._dims.index("channel")
        dims = list(range(self._data.ndim))
        dims.pop(ch_dim)
        # take min() across all but the channel axis
        # (if the abs becomes memory intensive we could iterate over channels)
        use_data = self._data
        if use_data.dtype.kind == "c":
            use_data = np.abs(use_data)
        bad_value = use_data.min(axis=tuple(dims)) == 0
        # channels already marked bad are not reported
        bad_value &= ~np.isin(self.ch_names, self.info["bads"])
        if bad_value.any():
            chs = np.array(self.ch_names)[bad_value].tolist()
            s = _pl(bad_value.sum())
            warn(f'Zero value in spectrum for channel{s} {", ".join(chs)}', UserWarning)

    def _returns_complex_tapers(self, **method_kw):
        """Return True when multitaper output keeps per-taper complex values."""
        return self.method == "multitaper" and method_kw.get("output") == "complex"

    def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose):
        """Run the stored PSD function on ``data`` and store the results.

        Sets ``_data``, ``_weights``, ``_freqs`` and the expected ``_shape``
        (checked later by ``_check_values``), then deletes the temporary
        ``_picks``, ``_psd_func`` and ``_time_mask`` attributes.
        """
        # make the spectra
        result = self._psd_func(
            data, self.sfreq, fmin=fmin, fmax=fmax, n_jobs=n_jobs, verbose=verbose
        )
        # assign ._data (handling unaggregated multitaper output)
        if self._returns_complex_tapers(**method_kw):
            fourier_coefs, freqs, weights = result
            self._data = fourier_coefs
            self._weights = weights
        else:
            psds, freqs = result
            self._data = psds
            self._weights = None
        # assign properties (._data already assigned above)
        self._freqs = freqs
        # this is *expected* shape, it gets asserted later in _check_values()
        # (and then deleted afterwards)
        self._shape = (len(self.ch_names), len(self.freqs))
        # append n_welch_segments (use "" as .get() default since None considered valid)
        if method_kw.get("average", "") in (None, False):
            n_welch_segments = _compute_n_welch_segments(data.shape[-1], method_kw)
            self._shape += (n_welch_segments,)
        # insert n_tapers
        if self._returns_complex_tapers(**method_kw):
            self._shape = self._shape[:-1] + (self._weights.size,) + self._shape[-1:]
        # we don't need these anymore, and they make save/load harder
        del self._picks
        del self._psd_func
        del self._time_mask

    @property
    def _detrend_picks(self):
        """Provide compatibility with __iter__."""
        return list()

    @property
    def ch_names(self):
        """Channel names, taken from ``self.info``."""
        return self.info["ch_names"]

    @property
    def data(self):
        """The spectral data array."""
        return self._data

    @property
    def freqs(self):
        """Frequencies of the spectral bins."""
        return self._freqs

    @property
    def method(self):
        """Name of the method used to compute the spectrum."""
        return self._method

    @property
    def nave(self):
        """Number of averaged trials, or ``None`` if unknown."""
        return self._nave

    @property
    def weights(self):
        """Taper weights (only set for per-taper complex multitaper output)."""
        return self._weights

    @property
    def sfreq(self):
        """Sampling frequency of the data the spectrum was computed from."""
        return self._sfreq

    @property
    def shape(self):
        """Shape of the spectral data array."""
        return self._data.shape

    def copy(self):
        """Return copy of the Spectrum instance.

        Returns
        -------
        spectrum : instance of Spectrum
            A copy of the object.
        """
        return deepcopy(self)

    @fill_doc
    def get_data(
        self, picks=None, exclude="bads", fmin=0, fmax=np.inf, return_freqs=False
    ):
        """Get spectrum data in NumPy array format.

        Parameters
        ----------
        %(picks_good_data_noref)s
        %(exclude_spectrum_get_data)s
        %(fmin_fmax_psd)s
        return_freqs : bool
            Whether to return the frequency bin values for the requested
            frequency range. Default is ``False``.

        Returns
        -------
        data : array
            The requested data in a NumPy array. It has the same number and
            order of dimensions as the spectrum's internal data.
        freqs : array
            The frequency values for the requested range. Only returned if
            ``return_freqs`` is ``True``.
        """
        picks = _picks_to_idx(
            self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False
        )
        # inclusive frequency window [fmin, fmax]
        fmin_idx = np.searchsorted(self.freqs, fmin)
        fmax_idx = np.searchsorted(self.freqs, fmax, side="right")
        freq_picks = np.arange(fmin_idx, fmax_idx)
        freq_axis = self._dims.index("freq")
        chan_axis = self._dims.index("channel")
        # ``np.take`` with an index *array* (never a scalar) keeps the indexed
        # axis even when a single channel or frequency bin is selected:
        # `_picks_to_idx` always returns an array of picks and np.arange always
        # returns an array of freq bin indices. The result therefore has the
        # same number of dimensions as ``self._data``. That is 2D for plain
        # spectra, and more when epoch, segment or taper axes are present.
        data = self._data.take(picks, chan_axis).take(freq_picks, freq_axis)
        if return_freqs:
            freqs = self._freqs[fmin_idx:fmax_idx]
            return (data, freqs)
        return data

    @fill_doc
    def plot(
        self,
        *,
        picks=None,
        average=False,
        dB=True,
        amplitude=False,
        xscale="linear",
        ci="sd",
        ci_alpha=0.3,
        color="black",
        alpha=None,
        spatial_colors=True,
        sphere=None,
        exclude=(),
        axes=None,
        show=True,
    ):
        """%(plot_psd_doc)s.

        Parameters
        ----------
        %(picks_all_data_noref)s

            .. versionchanged:: 1.5
                In version 1.5, the default behavior changed so that all
                :term:`data channels` (not just "good" data channels) are shown by
                default.
        average : bool
            Whether to average across channels before plotting. If ``True``, interactive
            plotting of scalp topography is disabled, and parameters ``ci`` and
            ``ci_alpha`` control the style of the confidence band around the mean.
            Default is ``False``.
        %(dB_spectrum_plot)s
        amplitude : bool
            Whether to plot an amplitude spectrum (``True``) or power spectrum
            (``False``).

            .. versionchanged:: 1.8
                In version 1.8, the default changed to ``amplitude=False``.
        %(xscale_plot_psd)s
        ci : float | 'sd' | 'range' | None
            Type of confidence band drawn around the mean when ``average=True``. If
            ``'sd'`` the band spans ±1 standard deviation across channels. If
            ``'range'`` the band spans the range across channels at each frequency. If a
            :class:`float`, it indicates the (bootstrapped) confidence interval to
            display, and must satisfy ``0 < ci <= 100``. If ``None``, no band is drawn.
            Default is ``sd``.
        ci_alpha : float
            Opacity of the confidence band. Must satisfy ``0 <= ci_alpha <= 1``. Default
            is 0.3.
        %(color_plot_psd)s
        alpha : float | None
            Opacity of the spectrum line(s). If :class:`float`, must satisfy
            ``0 <= alpha <= 1``. If ``None``, opacity will be ``1`` when
            ``average=True`` and ``0.1`` when ``average=False``. Default is ``None``.
        %(spatial_colors_psd)s
        %(sphere_topomap_auto)s
        %(exclude_spectrum_plot)s

            .. versionchanged:: 1.5
                In version 1.5, the default behavior changed from ``exclude='bads'`` to
                ``exclude=()``.
        %(axes_spectrum_plot_topomap)s
        %(show)s

        Returns
        -------
        fig : instance of matplotlib.figure.Figure
            Figure with spectra plotted in separate subplots for each channel type.
        """
        # Must nest this _mpl_figure import because of the BACKEND global
        # stuff
        from ..viz._mpl_figure import _line_figure, _split_picks_by_type

        # arg checking
        ci = _check_ci(ci)
        _check_option("xscale", xscale, ("log", "linear"))
        sphere = _check_sphere(sphere, self.info)
        # defaults
        scalings = _handle_default("scalings", None)
        titles = _handle_default("titles", None)
        units = _handle_default("units", None)

        _validate_type(amplitude, bool, "amplitude")
        estimate = "amplitude" if amplitude else "power"

        logger.info(f"Plotting {estimate} spectral density ({dB=}).")

        # split picks by channel type
        picks = _picks_to_idx(
            self.info, picks, "data", exclude=exclude, with_ref_meg=False
        )
        (picks_list, units_list, scalings_list, titles_list) = _split_picks_by_type(
            self, picks, units, scalings, titles
        )
        # prepare data (e.g. aggregate across dims, convert complex to power)
        psd_list = [
            self._prepare_data_for_plot(
                self._data.take(_p, axis=self._dims.index("channel"))
            )
            for _p in picks_list
        ]
        # initialize figure
        fig, axes = _line_figure(self, axes, picks=picks)
        # don't add ylabels & titles if figure has unexpected number of axes
        make_label = len(axes) == len(fig.axes)
        # Plot Frequency [Hz] xlabel only on the last axis
        xlabels_list = [False] * (len(axes) - 1) + [True]
        # plot
        _plot_psd(
            self,
            fig,
            self.freqs,
            psd_list,
            picks_list,
            titles_list,
            units_list,
            scalings_list,
            axes,
            make_label,
            color,
            area_mode=ci,
            area_alpha=ci_alpha,
            dB=dB,
            estimate=estimate,
            average=average,
            spatial_colors=spatial_colors,
            xscale=xscale,
            line_alpha=alpha,
            sphere=sphere,
            xlabels_list=xlabels_list,
        )
        plt_show(show, fig)
        return fig

    @fill_doc
    def plot_topo(
        self,
        *,
        dB=True,
        layout=None,
        color="w",
        fig_facecolor="k",
        axis_facecolor="k",
        axes=None,
        block=False,
        show=True,
    ):
        """Plot power spectral density, separately for each channel.

        Parameters
        ----------
        %(dB_spectrum_plot_topo)s
        %(layout_spectrum_plot_topo)s
        %(color_spectrum_plot_topo)s
        %(fig_facecolor)s
        %(axis_facecolor)s
        %(axes_spectrum_plot_topo)s
        %(block)s
        %(show)s

        Returns
        -------
        fig : instance of matplotlib.figure.Figure
            Figure distributing one image per channel across sensor topography.
        """
        if layout is None:
            layout = find_layout(self.info)

        psds, freqs = self.get_data(return_freqs=True)
        # prepare data (e.g. aggregate across dims, convert complex to power)
        psds = self._prepare_data_for_plot(psds)
        if dB:
            psds = 10 * np.log10(psds)
            y_label = "dB"
        else:
            y_label = "Power"
        # callbacks: one draws all channels in the unified topo layout, the
        # other draws a single channel when its axes are clicked
        show_func = partial(
            _plot_timeseries_unified, data=[psds], color=color, times=[freqs]
        )
        click_func = partial(_plot_timeseries, data=[psds], color=color, times=[freqs])
        picks = _pick_data_channels(self.info)
        info = pick_info(self.info, picks)
        fig = _plot_topo(
            info,
            times=freqs,
            show_func=show_func,
            click_func=click_func,
            layout=layout,
            axis_facecolor=axis_facecolor,
            fig_facecolor=fig_facecolor,
            x_label="Frequency (Hz)",
            unified=True,
            y_label=y_label,
            axes=axes,
        )
        plt_show(show, block=block)
        return fig

    @fill_doc
    def plot_topomap(
        self,
        bands=None,
        ch_type=None,
        *,
        normalize=False,
        agg_fun=None,
        dB=False,
        sensors=True,
        show_names=False,
        mask=None,
        mask_params=None,
        contours=6,
        outlines="head",
        sphere=None,
        image_interp=_INTERPOLATION_DEFAULT,
        extrapolate=_EXTRAPOLATE_DEFAULT,
        border=_BORDER_DEFAULT,
        res=64,
        size=1,
        cmap=None,
        vlim=(None, None),
        cnorm=None,
        colorbar=True,
        cbar_fmt="auto",
        units=None,
        axes=None,
        show=True,
    ):
        """Plot scalp topography of PSD for chosen frequency bands.

        Parameters
        ----------
        %(bands_psd_topo)s
        %(ch_type_topomap_psd)s
        %(normalize_psd_topo)s
        %(agg_fun_psd_topo)s
        %(dB_plot_topomap)s
        %(sensors_topomap)s
        %(show_names_topomap)s
        %(mask_evoked_topomap)s
        %(mask_params_topomap)s
        %(contours_topomap)s
        %(outlines_topomap)s
        %(sphere_topomap_auto)s
        %(image_interp_topomap)s
        %(extrapolate_topomap)s
        %(border_topomap)s
        %(res_topomap)s
        %(size_topomap)s
        %(cmap_topomap)s
        %(vlim_plot_topomap_psd)s
        %(cnorm)s
        %(colorbar_topomap)s
        %(cbar_fmt_topomap_psd)s
        %(units_topomap)s
        %(axes_spectrum_plot_topomap)s
        %(show)s

        Returns
        -------
        fig : instance of Figure
            Figure showing one scalp topography per frequency band.
        """
        ch_type = _get_plot_ch_type(self, ch_type)
        if units is None:
            units = _handle_default("units", None)
        # ``units`` may be a per-channel-type mapping or a single value
        unit = units[ch_type] if hasattr(units, "keys") else units
        scalings = _handle_default("scalings", None)
        scaling = scalings[ch_type]

        (
            picks,
            pos,
            merge_channels,
            names,
            ch_type,
            sphere,
            clip_origin,
        ) = _prepare_topomap_plot(self, ch_type, sphere=sphere)
        outlines = _make_head_outlines(sphere, pos, outlines, clip_origin)

        psds, freqs = self.get_data(picks=picks, return_freqs=True)
        # prepare data (e.g. aggregate across dims, convert complex to power)
        psds = self._prepare_data_for_plot(psds)
        # the data are power at this point, so the amplitude scaling is squared
        psds *= scaling**2

        if merge_channels:
            psds, names = _merge_ch_data(psds, ch_type, names, method="mean")

        names = _prepare_sensor_names(names, show_names)
        return plot_psds_topomap(
            psds=psds,
            freqs=freqs,
            pos=pos,
            bands=bands,
            ch_type=ch_type,
            normalize=normalize,
            agg_fun=agg_fun,
            dB=dB,
            sensors=sensors,
            names=names,
            mask=mask,
            mask_params=mask_params,
            contours=contours,
            outlines=outlines,
            sphere=sphere,
            image_interp=image_interp,
            extrapolate=extrapolate,
            border=border,
            res=res,
            size=size,
            cmap=cmap,
            vlim=vlim,
            cnorm=cnorm,
            colorbar=colorbar,
            cbar_fmt=cbar_fmt,
            unit=unit,
            axes=axes,
            show=show,
        )

    def _prepare_data_for_plot(self, data):
        """Reduce ``data`` to real-valued power for plotting.

        ``data`` is assumed to have the same dimension order as
        ``self._dims``. Welch segments are median-aggregated, multitaper tapers
        are combined via ``_psd_from_mt``, remaining complex values are
        converted to power, and epochs are averaged.
        """
        # handle unaggregated Welch
        if "segment" in self._dims:
            logger.info("Aggregating Welch estimates (median) before plotting...")
            data = np.nanmedian(data, axis=self._dims.index("segment"))
        # handle unaggregated multitaper (also handles complex -> power)
        elif "taper" in self._dims:
            logger.info("Aggregating multitaper estimates before plotting...")
            data = _psd_from_mt(data, self.weights)

        # handle complex data (should only be Welch remaining)
        if np.iscomplexobj(data):
            data = (data * data.conj()).real  # Scaling may be slightly off

        # handle epochs
        if "epoch" in self._dims:
            # XXX TODO FIXME decide how to properly aggregate across repeated
            # measures (epochs) and non-repeated but correlated measures
            # (channels) when calculating stddev or a CI. For across-channel
            # aggregation, doi:10.1007/s10162-012-0321-8 used hotellings T**2
            # with a correction factor that estimated data rank using monte
            # carlo simulations; seems like we could use our own data rank
            # estimation methods to similar effect. Their exact approach used
            # complex spectra though, here we've already converted to power;
            # not sure if that makes an important difference? Anyway that
            # aggregation would need to happen in the _plot_psd function
            # though, not here... for now we just average like we always did.

            # only log message if averaging will actually have an effect
            if data.shape[0] > 1:
                logger.info("Averaging across epochs before plotting...")
            # epoch axis should always be the first axis
            data = data.mean(axis=0)

        return data

    @verbose
    def save(self, fname, *, overwrite=False, verbose=None):
        """Save spectrum data to disk (in HDF5 format).

        Parameters
        ----------
        fname : path-like
            Path of file to save to.
        %(overwrite)s
        %(verbose)s

        See Also
        --------
        mne.time_frequency.read_spectrum
        """
        _, write_hdf5 = _import_h5io_funcs()
        check_fname(fname, "spectrum", (".h5", ".hdf5"))
        fname = _check_fname(fname, overwrite=overwrite, verbose=verbose)
        # the dict written here is the same one used for pickling
        out = self.__getstate__()
        write_hdf5(fname, out, overwrite=overwrite, title="mnepython")

    @verbose
    def to_data_frame(
        self, picks=None, index=None, copy=True, long_format=False, *, verbose=None
    ):
        """Export data in tabular structure as a pandas DataFrame.

        Channels are converted to columns in the DataFrame. By default,
        an additional column "freq" is added, unless ``index='freq'``
        (in which case frequency values form the DataFrame's index).

        Parameters
        ----------
        %(picks_all)s
        index : str | list of str | None
            Kind of index to use for the DataFrame. If ``None``, a sequential
            integer index (:class:`pandas.RangeIndex`) will be used. If a
            :class:`str`, a :class:`pandas.Index` will be used (see Notes). If
            a list of two or more string values, a :class:`pandas.MultiIndex`
            will be used. Defaults to ``None``.
        %(copy_df)s
        %(long_format_df_spe)s
        %(verbose)s

        Returns
        -------
        %(df_return)s

        Notes
        -----
        Valid values for ``index`` depend on whether the Spectrum was created
        from continuous data (:class:`~mne.io.Raw`, :class:`~mne.Evoked`) or
        discontinuous data (:class:`~mne.Epochs`). For continuous data, only
        ``None`` or ``'freq'`` is supported. For discontinuous data, additional
        valid values are ``'epoch'`` and ``'condition'``, or a :class:`list`
        comprising some of the valid string values (e.g.,
        ``['freq', 'epoch']``).
        """
        # check pandas once here, instead of in each private utils function
        pd = _check_pandas_installed()  # noqa
        # triage for Epoch-derived or unaggregated spectra
        from_epo = _get_instance_type_string(self) == "Epochs"
        unagg_welch = "segment" in self._dims
        unagg_mt = "taper" in self._dims
        # arg checking
        valid_index_args = ["freq"]
        if from_epo:
            valid_index_args += ["epoch", "condition"]
        index = _check_pandas_index_arguments(index, valid_index_args)
        # get data
        picks = _picks_to_idx(self.info, picks, "all", exclude=())
        data = self.get_data(picks)
        if copy:
            data = data.copy()
        # reshape
        if unagg_mt:
            # taper axis precedes freq; move freq before it so taper is last
            data = np.moveaxis(data, self._dims.index("freq"), -2)
        if from_epo:
            n_epochs, n_picks, n_freqs = data.shape[:3]
        else:
            n_epochs, n_picks, n_freqs = (1,) + data.shape[:2]
        n_segs = data.shape[-1] if unagg_mt or unagg_welch else 1
        data = np.moveaxis(data, self._dims.index("channel"), -1)
        # at this point, should be ([epoch], freq, [segment/taper], channel)
        data = data.reshape(n_epochs * n_freqs * n_segs, n_picks)
        # prepare extra columns / multiindex
        mindex = list()
        default_index = list()
        if from_epo:
            rev_event_id = {v: k for k, v in self.event_id.items()}
            _conds = [rev_event_id[k] for k in self.events[:, 2]]
            conditions = np.repeat(_conds, n_freqs * n_segs)
            epoch_nums = np.repeat(self.selection, n_freqs * n_segs)
            mindex.extend([("condition", conditions), ("epoch", epoch_nums)])
            default_index.extend(["condition", "epoch"])
        freqs = np.tile(np.repeat(self.freqs, n_segs), n_epochs)
        mindex.append(("freq", freqs))
        default_index.append("freq")
        if unagg_mt or unagg_welch:
            name = "taper" if unagg_mt else "segment"
            seg_nums = np.tile(np.arange(n_segs), n_epochs * n_freqs)
            mindex.append((name, seg_nums))
            default_index.append(name)
        # build DataFrame
        df = _build_data_frame(
            self, data, picks, long_format, mindex, index, default_index=default_index
        )
        return df

    def units(self, latex=False):
        """Get the spectrum units for each channel type.

        Parameters
        ----------
        latex : bool
            Whether to format the unit strings as LaTeX. Default is ``False``.

        Returns
        -------
        units : dict
            Mapping from channel type to a string representation of the units
            for that channel type.
        """
        units = _handle_default("si_units", None)
        return {
            ch_type: _format_units_psd(units[ch_type], power=True, latex=latex)
            for ch_type in sorted(self.get_channel_types(unique=True))
        }
|
||
|
||
|
||
@fill_doc
class Spectrum(BaseSpectrum):
    """Data object for spectral representations of continuous data.

    .. warning:: The preferred means of creating Spectrum objects from
                 continuous or averaged data is via the instance methods
                 :meth:`mne.io.Raw.compute_psd` or
                 :meth:`mne.Evoked.compute_psd`. Direct class instantiation
                 is not supported.

    Parameters
    ----------
    inst : instance of Raw or Evoked
        The data from which to compute the frequency spectrum.
    %(method_psd_auto)s
        ``'auto'`` (default) uses Welch's method for continuous data
        and multitaper for :class:`~mne.Evoked` data.
    %(fmin_fmax_psd)s
    %(tmin_tmax_psd)s
    %(picks_good_data_noref)s
    %(exclude_psd)s
    %(proj_psd)s
    %(remove_dc)s
    %(reject_by_annotation_psd)s
    %(n_jobs)s
    %(verbose)s
    %(method_kw_psd)s

    Attributes
    ----------
    ch_names : list
        The channel names.
    freqs : array
        Frequencies at which the amplitude, power, or fourier coefficients
        have been computed.
    %(info_not_none)s
    method : ``'welch'``| ``'multitaper'``
        The method used to compute the spectrum.
    nave : int | None
        The number of trials averaged together when generating the spectrum. ``None``
        indicates no averaging is known to have occurred.
    weights : array | None
        The weights for each taper. Only present if spectra computed with
        ``method='multitaper'`` and ``output='complex'``.

        .. versionadded:: 1.8

    See Also
    --------
    EpochsSpectrum
    SpectrumArray
    mne.io.Raw.compute_psd
    mne.Epochs.compute_psd
    mne.Evoked.compute_psd

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        inst,
        method,
        fmin,
        fmax,
        tmin,
        tmax,
        picks,
        exclude,
        proj,
        remove_dc,
        reject_by_annotation,
        *,
        n_jobs,
        verbose=None,
        **method_kw,
    ):
        from ..io import BaseRaw

        # triage reading from file: a dict is a serialized state, not data
        if isinstance(inst, dict):
            self.__setstate__(inst)
            return
        # do the basic setup
        super().__init__(
            inst,
            method,
            fmin,
            fmax,
            tmin,
            tmax,
            picks,
            exclude,
            proj,
            remove_dc,
            n_jobs=n_jobs,
            verbose=verbose,
            **method_kw,
        )
        # get just the data we want
        if isinstance(self.inst, BaseRaw):
            start, stop = np.where(self._time_mask)[0][[0, -1]]
            # ask get_data to fill annotation-rejected spans with NaN
            rba = "NaN" if reject_by_annotation else None
            data = self.inst.get_data(
                self._picks, start, stop + 1, reject_by_annotation=rba
            )
            # Test the cheap method comparison first. The NaN scan allocates a
            # boolean array as large as ``data``, and its result only matters
            # for multitaper, so skip it on the (default) Welch path.
            if method == "multitaper" and np.isnan(data).any():
                raise NotImplementedError(
                    'Cannot use method="multitaper" when reject_by_annotation=True. '
                    'Please use method="welch" instead.'
                )

        else:  # Evoked
            data = self.inst.data[self._picks][:, self._time_mask]
        # set nave (Raw has no ``nave`` attribute, giving None)
        self._nave = getattr(inst, "nave", None)
        # compute the spectra
        self._compute_spectra(data, fmin, fmax, n_jobs, method_kw, verbose)
        # check for correct shape and bad values
        self._check_values()
        del self._shape  # calculated from self._data henceforth
        # save memory: drop the reference to the source instance
        del self.inst

    def __getitem__(self, item):
        """Get Spectrum data.

        Parameters
        ----------
        item : int | slice | array-like
            Indexing is similar to a :class:`NumPy array<numpy.ndarray>`; see
            Notes.

        Returns
        -------
        %(getitem_spectrum_return)s

        Notes
        -----
        Integer-, list-, and slice-based indexing is possible:

        - ``spectrum[0]`` gives all frequency bins in the first channel
        - ``spectrum[:3]`` gives all frequency bins in the first 3 channels
        - ``spectrum[[0, 2], 5]`` gives the value in the sixth frequency bin of
          the first and third channels
        - ``spectrum[(4, 7)]`` is the same as ``spectrum[4, 7]``.

        .. note::

           Unlike :class:`~mne.io.Raw` objects (which returns a tuple of the
           requested data values and the corresponding times), accessing
           :class:`~mne.time_frequency.Spectrum` values via subscript does
           **not** return the corresponding frequency bin values. If you need
           them, use ``spectrum.freqs[freq_indices]`` or
           ``spectrum.get_data(..., return_freqs=True)``.
        """
        from ..io import BaseRaw

        # reuse Raw's indexing machinery, bound to this Spectrum instance
        self._parse_get_set_params = partial(BaseRaw._parse_get_set_params, self)
        return BaseRaw._getitem(self, item, return_times=False)
|
||
|
||
|
||
def _check_data_shape(data, info, freqs, dim_names, weights, is_epoched):
    """Validate array-spectrum inputs against each other.

    Parameters
    ----------
    data : ndarray
        The spectral data.
    info : Info
        Measurement info; the number of good data channels must match the
        length of the ``'channel'`` axis of ``data``.
    freqs : ndarray
        Frequencies; ``freqs.size`` must match the length of the ``'freq'``
        axis of ``data``.
    dim_names : sequence of str
        Names of the axes of ``data``, in order. Names must be unique.
    weights : ndarray | None
        Taper weights; required, with one entry per taper, when ``'taper'``
        is among ``dim_names``.
    is_epoched : bool
        Whether ``data`` has a leading ``'epoch'`` axis.

    Raises
    ------
    ValueError
        If any consistency check fails.
    """
    dims = list(dim_names)
    if data.ndim != len(dims):
        raise ValueError(
            f"Expected data to have {len(dims)} dimensions, got {data.ndim}."
        )
    # Repeated names (e.g. two 'freq' axes) would otherwise slip through the
    # positional checks below, since ``list.index`` only finds the first match.
    if len(set(dims)) != len(dims):
        raise ValueError(
            f"Dimension names in `dim_names` must be unique, got {tuple(dims)}."
        )

    allowed_dims = ["epoch", "channel", "freq", "segment", "taper"]
    if not is_epoched:
        allowed_dims.remove("epoch")
    # TODO maybe we should be nice and allow plural versions of each dimname?
    for dim in dims:
        _check_option("dim_names", dim, allowed_dims)
    if "channel" not in dims or "freq" not in dims:
        raise ValueError("Both 'channel' and 'freq' must be present in `dim_names`.")

    if dims.index("channel") != int(is_epoched):
        raise ValueError(
            f"'channel' must be the {'second' if is_epoched else 'first'} dimension of "
            "the data."
        )
    want_n_chan = _pick_data_channels(info).size
    got_n_chan = data.shape[dims.index("channel")]
    if got_n_chan != want_n_chan:
        raise ValueError(
            f"The number of channels in `data` ({got_n_chan}) must match the number of "
            f"good data channels in `info` ({want_n_chan})."
        )

    # given we limit max array size and ensure channel & freq dims present, only one of
    # taper or segment can be present
    if "taper" in dims:
        if dims[-2] != "taper":  # _psd_from_mt assumes this (called when plotting)
            raise ValueError(
                "'taper' must be the second to last dimension of the data."
            )
        # expect weights for each taper
        actual = None if weights is None else weights.size
        expected = data.shape[dims.index("taper")]
        if actual != expected:
            raise ValueError(
                f"Expected size of `weights` to be {expected} to match 'n_tapers' in "
                f"`data`, got {actual}."
            )
    elif "segment" in dims and dims[-1] != "segment":
        raise ValueError("'segment' must be the last dimension of the data.")

    # freq being in wrong position ruled out by above checks
    want_n_freq = freqs.size
    got_n_freq = data.shape[dims.index("freq")]
    if got_n_freq != want_n_freq:
        raise ValueError(
            f"The number of frequencies in `data` ({got_n_freq}) must match the number "
            f"of elements in `freqs` ({want_n_freq})."
        )
|
||
|
||
|
||
@fill_doc
class SpectrumArray(Spectrum):
    """Data object for precomputed spectral data (in NumPy array format).

    Parameters
    ----------
    data : ndarray, shape (n_channels, [n_tapers], n_freqs, [n_segments])
        The spectra for each channel.
    %(info_not_none)s
    %(freqs_tfr_array)s
    dim_names : tuple of str
        The names of the dimensions of ``data``, in the order they occur. Must
        contain ``'channel'`` and ``'freq'``; for unaggregated estimates, also
        include either a ``'segment'`` (e.g., Welch-like algorithms) or a
        ``'taper'`` (e.g., multitaper algorithms) dimension. When including
        ``'taper'``, also pass the ``weights`` parameter.

        .. versionadded:: 1.8
    weights : ndarray | None
        Weights for the ``'taper'`` dimension, if present (see ``dim_names``).

        .. versionadded:: 1.8
    %(verbose)s

    See Also
    --------
    mne.create_info
    mne.EvokedArray
    mne.io.RawArray
    EpochsSpectrumArray

    Notes
    -----
    %(notes_spectrum_array)s

    .. versionadded:: 1.6
    """

    @verbose
    def __init__(
        self,
        data,
        info,
        freqs,
        dim_names=("channel", "freq"),
        weights=None,
        *,
        verbose=None,
    ):
        # Layout is (channel, [taper], freq, [segment]): at most one optional
        # axis may accompany the mandatory channel and freq axes.
        _check_option("data.ndim", data.ndim, (2, 3))
        _check_data_shape(data, info, freqs, dim_names, weights, is_epoched=False)

        # complex input means the caller supplied Fourier coefficients
        kind = "Fourier Coefficients" if np.iscomplexobj(data) else "Power Spectrum"
        state = dict(
            method="unknown",
            data=data,
            sfreq=info["sfreq"],
            dims=dim_names,
            freqs=freqs,
            inst_type_str="Array",
            data_type=kind,
            info=info,
            weights=weights,
        )
        self.__setstate__(state)
|
||
|
||
|
||
@fill_doc
class EpochsSpectrum(BaseSpectrum, GetEpochsMixin):
    """Data object for spectral representations of epoched data.

    .. warning:: The preferred means of creating Spectrum objects from Epochs
                 is via the instance method :meth:`mne.Epochs.compute_psd`.
                 Direct class instantiation is not supported.

    Parameters
    ----------
    inst : instance of Epochs
        The data from which to compute the frequency spectrum.
    %(method_psd)s
    %(fmin_fmax_psd)s
    %(tmin_tmax_psd)s
    %(picks_good_data_noref)s
    %(exclude_psd)s
    %(proj_psd)s
    %(remove_dc)s
    %(n_jobs)s
    %(verbose)s
    %(method_kw_psd)s

    Attributes
    ----------
    ch_names : list
        The channel names.
    freqs : array
        Frequencies at which the amplitude, power, or fourier coefficients
        have been computed.
    %(info_not_none)s
    method : ``'welch'``| ``'multitaper'``
        The method used to compute the spectrum.
    weights : array | None
        The weights for each taper. Only present if spectra computed with
        ``method='multitaper'`` and ``output='complex'``.

        .. versionadded:: 1.8

    See Also
    --------
    EpochsSpectrumArray
    Spectrum
    mne.Epochs.compute_psd

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        inst,
        method,
        fmin,
        fmax,
        tmin,
        tmax,
        picks,
        exclude,
        proj,
        remove_dc,
        *,
        n_jobs,
        verbose=None,
        **method_kw,
    ):
        # A dict ``inst`` is a serialized state (file read or average()), so
        # restore it directly instead of computing anything.
        if isinstance(inst, dict):
            self.__setstate__(inst)
            return
        super().__init__(
            inst,
            method,
            fmin,
            fmax,
            tmin,
            tmax,
            picks,
            exclude,
            proj,
            remove_dc,
            n_jobs=n_jobs,
            verbose=verbose,
            **method_kw,
        )
        # restrict to the picked channels and the requested time window
        data = self.inst._get_data(picks=self._picks, on_empty="raise")[
            :, :, self._time_mask
        ]
        self._compute_spectra(data, fmin, fmax, n_jobs, method_kw, verbose)
        # prepend the epoch axis to the dimension bookkeeping before validating
        self._dims = ("epoch",) + self._dims
        self._shape = (len(self.inst),) + self._shape
        self._check_values()
        del self._shape
        # epoch bookkeeping used by to_data_frame() ...
        self.event_id = self.inst.event_id.copy()
        self.events = self.inst.events.copy()
        self.selection = self.inst.selection.copy()
        # ... and by __getitem__()
        self.drop_log = deepcopy(self.inst.drop_log)
        self._metadata = self.inst.metadata
        # release the source Epochs to save memory
        del self.inst

    def __getitem__(self, item):
        """Subselect epochs from an EpochsSpectrum.

        Parameters
        ----------
        item : int | slice | array-like | str
            Access options are the same as for :class:`~mne.Epochs` objects,
            see the docstring of :meth:`mne.Epochs.__getitem__` for
            explanation.

        Returns
        -------
        %(getitem_epochspectrum_return)s
        """
        return super().__getitem__(item)

    def __getstate__(self):
        """Serialize the object, including its epoch bookkeeping."""
        state = super().__getstate__()
        state["metadata"] = self._metadata
        state["drop_log"] = self.drop_log
        state["event_id"] = self.event_id
        state["events"] = self.events
        state["selection"] = self.selection
        return state

    def __setstate__(self, state):
        """Restore the object, including its epoch bookkeeping, from ``state``."""
        super().__setstate__(state)
        attr_for_key = (
            ("metadata", "_metadata"),
            ("drop_log", "drop_log"),
            ("event_id", "event_id"),
            ("events", "events"),
            ("selection", "selection"),
        )
        for key, attr in attr_for_key:
            setattr(self, attr, state[key])

    def average(self, method="mean"):
        """Average the spectra across epochs.

        Parameters
        ----------
        method : 'mean' | 'median' | callable
            How to aggregate spectra across epochs. If callable, must take a
            :class:`NumPy array<numpy.ndarray>` of shape
            ``(n_epochs, n_channels, n_freqs)`` and return an array of shape
            ``(n_channels, n_freqs)``. Default is ``'mean'``.

        Returns
        -------
        spectrum : instance of Spectrum
            The aggregated spectrum object.
        """
        _validate_type(method, ("str", "callable"), "method")
        combine = _make_combine_callable(
            method, axis=0, valid=("mean", "median"), keepdims=False
        )
        if not callable(combine):
            raise ValueError(
                '"method" must be a valid string or callable, '
                f"got a {type(combine).__name__} ({combine})."
            )
        # unaggregated spectral estimates cannot be averaged (checked in order)
        unsupported = {
            "segment": (
                "Averaging individual Welch segments across epochs is not "
                "supported. Consider averaging the signals before computing "
                "the Welch spectrum estimates."
            ),
            "taper": (
                "Averaging multitaper tapers across epochs is not supported. Consider "
                "averaging the signals before computing the complex spectrum."
            ),
        }
        for dim, msg in unsupported.items():
            if dim in self._dims:
                raise NotImplementedError(msg)
        # serialize, then collapse the leading epoch axis
        state = super().__getstate__()
        state["nave"] = state["data"].shape[0]
        state["data"] = combine(state["data"])
        state["dims"] = state["dims"][1:]
        state["data_type"] = f'Averaged {state["data_type"]}'
        # the remaining constructor arguments are unused when ``inst`` is a dict
        return Spectrum(
            state,
            method=None,
            fmin=None,
            fmax=None,
            tmin=None,
            tmax=None,
            picks=None,
            exclude=(),
            proj=None,
            remove_dc=None,
            reject_by_annotation=None,
            n_jobs=None,
            verbose=None,
        )
|
||
|
||
|
||
@fill_doc
class EpochsSpectrumArray(EpochsSpectrum):
    """Data object for precomputed epoched spectral data (in NumPy array format).

    Parameters
    ----------
    data : ndarray, shape (n_epochs, n_channels, [n_tapers], n_freqs, [n_segments])
        The spectra for each channel in each epoch.
    %(info_not_none)s
    %(freqs_tfr_array)s
    %(events_epochs)s
    %(event_id)s
    dim_names : tuple of str
        The name of the dimensions in the data, in the order they occur. Must contain
        ``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include
        either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g.,
        multitaper algorithms) dimension. If including ``'taper'``, you should also pass
        a ``weights`` parameter.

        .. versionadded:: 1.8
    weights : ndarray | None
        Weights for the ``'taper'`` dimension, if present (see ``dim_names``).

        .. versionadded:: 1.8
    %(verbose)s

    See Also
    --------
    mne.create_info
    mne.EpochsArray
    SpectrumArray

    Notes
    -----
    %(notes_spectrum_array)s

    .. versionadded:: 1.6
    """

    @verbose
    def __init__(
        self,
        data,
        info,
        freqs,
        events=None,
        event_id=None,
        dim_names=("epoch", "channel", "freq"),
        weights=None,
        *,
        verbose=None,
    ):
        # (epoch, channel, [taper], freq, [segment])
        _check_option("data.ndim", data.ndim, (3, 4))  # only allow one extra dimension

        # Compare the leading name directly: ``list(dim_names).index("epoch")``
        # raises an unhelpful "'epoch' is not in list" when the name is absent.
        if list(dim_names)[:1] != ["epoch"]:
            raise ValueError("'epoch' must be the first dimension of `data`.")
        n_epochs = data.shape[0]
        if events is not None and n_epochs != events.shape[0]:
            raise ValueError(
                f"The first dimension of `data` ({n_epochs}) must match the first "
                f"dimension of `events` ({events.shape[0]})."
            )

        _check_data_shape(data, info, freqs, dim_names, weights, is_epoched=True)

        self.__setstate__(
            dict(
                method="unknown",
                data=data,
                sfreq=info["sfreq"],
                dims=dim_names,
                freqs=freqs,
                inst_type_str="Array",
                data_type=(
                    "Fourier Coefficients"
                    if np.iscomplexobj(data)
                    else "Power Spectrum"
                ),
                info=info,
                events=events,
                event_id=event_id,
                metadata=None,
                # every epoch is kept: consecutive selection, empty drop log
                selection=np.arange(n_epochs),
                drop_log=tuple(tuple() for _ in range(n_epochs)),
                weights=weights,
            )
        )
|
||
|
||
|
||
def read_spectrum(fname):
    """Load a :class:`mne.time_frequency.Spectrum` object from disk.

    Parameters
    ----------
    fname : path-like
        Path to a spectrum file in HDF5 format, which should end with ``.h5`` or
        ``.hdf5``.

    Returns
    -------
    spectrum : instance of Spectrum | EpochsSpectrum
        The loaded spectrum object. Epoched data (including data originally
        created with :class:`~mne.time_frequency.EpochsSpectrumArray`) are
        returned as :class:`~mne.time_frequency.EpochsSpectrum`.

    See Also
    --------
    mne.time_frequency.Spectrum.save
    """
    read_hdf5, _ = _import_h5io_funcs()
    _validate_type(fname, "path-like", "fname")
    fname = _check_fname(fname=fname, overwrite="read", must_exist=False)
    # read it in
    hdf5_dict = read_hdf5(fname, title="mnepython")
    # constructor arguments are ignored when a state dict is passed as ``inst``
    defaults = dict(
        method=None,
        fmin=None,
        fmax=None,
        tmin=None,
        tmax=None,
        picks=None,
        exclude=(),
        proj=None,
        remove_dc=None,
        reject_by_annotation=None,
        n_jobs=None,
        verbose=None,
    )
    # SpectrumArray and EpochsSpectrumArray both set inst_type_str="Array", so
    # the type string alone cannot distinguish epoched data; also check for an
    # "epoch" dimension so EpochsSpectrumArray files round-trip correctly.
    is_epoched = hdf5_dict["inst_type_str"] == "Epochs" or "epoch" in hdf5_dict.get(
        "dims", ()
    )
    Klass = EpochsSpectrum if is_epoched else Spectrum
    return Klass(hdf5_dict, **defaults)
|
||
|
||
|
||
def _check_ci(ci):
    """Normalize and validate a confidence-interval specification.

    ``"std"`` is accepted as an alias for ``"sd"``. Numeric values must lie in
    (0, 100] and are returned as a fraction (divided by 100); anything else
    must be one of ``None``, ``"sd"`` or ``"range"`` and is returned unchanged.
    """
    if ci == "std":  # be forgiving
        ci = "sd"
    if not _is_numeric(ci):
        _check_option("ci", ci, [None, "sd", "range"])
        return ci
    if not (0 < ci <= 100):
        raise ValueError(f"ci must satisfy 0 < ci <= 100, got {ci}")
    return ci / 100.0
|
||
|
||
|
||
def _compute_n_welch_segments(n_times, method_kw):
    """Compute the number of segments Welch's method will produce.

    Parameters
    ----------
    n_times : int
        Number of time samples in the signal.
    method_kw : dict
        Keyword arguments for the Welch method; only ``n_fft``, ``n_per_seg``
        and ``n_overlap`` are consulted. Missing keys fall back to the defaults
        of ``psd_array_welch``.

    Returns
    -------
    n_segments : int
        The expected number of (possibly overlapping) segments.
    """
    # Start from psd_array_welch's own defaults (looked up once), overridden by
    # any user-specified values; built fresh rather than mutated mid-iteration.
    welch_params = signature(psd_array_welch).parameters
    welch_kw = {
        param: method_kw.get(param, welch_params[param].default)
        for param in ("n_fft", "n_per_seg", "n_overlap")
    }
    # sanity check values / replace `None`s with real numbers
    _, n_per_seg, n_overlap = _check_nfft(n_times, **welch_kw)
    # compute expected number of segments
    step = n_per_seg - n_overlap
    return (n_times - n_overlap) // step
|
||
|
||
|
||
def _validate_method(method, instance_type):
    """Resolve ``method='auto'`` to a concrete method name and validate it.

    Parameters
    ----------
    method : str
        ``'welch'``, ``'multitaper'``, or ``'auto'``.
    instance_type : str
        Type string of the source instance. ``'auto'`` resolves to ``'welch'``
        when this starts with ``'Raw'`` and to ``'multitaper'`` otherwise.

    Returns
    -------
    method : str
        The resolved method name.
    """
    if method != "auto":
        resolved = method
    elif instance_type.startswith("Raw"):
        resolved = "welch"
    else:
        resolved = "multitaper"
    _check_option("method", resolved, ("welch", "multitaper"))
    return resolved
|