"""Container classes for spectral data.""" # Authors: The MNE-Python contributors. # License: BSD-3-Clause # Copyright the MNE-Python contributors. from copy import deepcopy from functools import partial from inspect import signature import numpy as np from .._fiff.meas_info import ContainsMixin, Info from .._fiff.pick import _pick_data_channels, _picks_to_idx, pick_info from ..channels.channels import UpdateChannelsMixin from ..channels.layout import _merge_ch_data, find_layout from ..defaults import ( _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT, _handle_default, ) from ..html_templates import _get_html_template from ..utils import ( GetEpochsMixin, _build_data_frame, _check_method_kwargs, _check_pandas_index_arguments, _check_pandas_installed, _check_sphere, _time_mask, _validate_type, fill_doc, legacy, logger, object_diff, repr_html, verbose, warn, ) from ..utils.check import ( _check_fname, _check_option, _import_h5io_funcs, _is_numeric, check_fname, ) from ..utils.misc import _pl from ..utils.spectrum import _get_instance_type_string, _split_psd_kwargs from ..viz.topo import _plot_timeseries, _plot_timeseries_unified, _plot_topo from ..viz.topomap import _make_head_outlines, _prepare_topomap_plot, plot_psds_topomap from ..viz.utils import ( _format_units_psd, _get_plot_ch_type, _make_combine_callable, _plot_psd, _prepare_sensor_names, plt_show, ) from .multitaper import _psd_from_mt, psd_array_multitaper from .psd import _check_nfft, psd_array_welch class SpectrumMixin: """Mixin providing spectral plotting methods to sensor-space containers.""" @legacy(alt=".compute_psd().plot()") @verbose def plot_psd( self, fmin=0, fmax=np.inf, tmin=None, tmax=None, picks=None, proj=False, reject_by_annotation=True, *, method="auto", average=False, dB=True, estimate="power", xscale="linear", area_mode="std", area_alpha=0.33, color="black", line_alpha=None, spatial_colors=True, sphere=None, exclude="bads", ax=None, show=True, n_jobs=1, verbose=None, 
**method_kw, ): """%(plot_psd_doc)s. Parameters ---------- %(fmin_fmax_psd)s %(tmin_tmax_psd)s %(picks_good_data_noref)s %(proj_psd)s %(reject_by_annotation_psd)s %(method_plot_psd_auto)s %(average_plot_psd)s %(dB_plot_psd)s %(estimate_plot_psd)s %(xscale_plot_psd)s %(area_mode_plot_psd)s %(area_alpha_plot_psd)s %(color_plot_psd)s %(line_alpha_plot_psd)s %(spatial_colors_psd)s %(sphere_topomap_auto)s .. versionadded:: 0.22.0 exclude : list of str | 'bads' Channels names to exclude from being shown. If 'bads', the bad channels are excluded. Pass an empty list to plot all channels (including channels marked "bad", if any). .. versionadded:: 0.24.0 %(ax_plot_psd)s %(show)s %(n_jobs)s %(verbose)s %(method_kw_psd)s Returns ------- fig : instance of Figure Figure with frequency spectra of the data channels. Notes ----- %(notes_plot_psd_meth)s """ init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot) return self.compute_psd(**init_kw).plot(**plot_kw) @legacy(alt=".compute_psd().plot_topo()") @verbose def plot_psd_topo( self, tmin=None, tmax=None, fmin=0, fmax=100, proj=False, *, method="auto", dB=True, layout=None, color="w", fig_facecolor="k", axis_facecolor="k", axes=None, block=False, show=True, n_jobs=None, verbose=None, **method_kw, ): """Plot power spectral density, separately for each channel. Parameters ---------- %(tmin_tmax_psd)s %(fmin_fmax_psd_topo)s %(proj_psd)s %(method_plot_psd_auto)s %(dB_spectrum_plot_topo)s %(layout_spectrum_plot_topo)s %(color_spectrum_plot_topo)s %(fig_facecolor)s %(axis_facecolor)s %(axes_spectrum_plot_topo)s %(block)s %(show)s %(n_jobs)s %(verbose)s %(method_kw_psd)s Defaults to ``dict(n_fft=2048)``. Returns ------- fig : instance of matplotlib.figure.Figure Figure distributing one image per channel across sensor topography. 
""" init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot_topo) return self.compute_psd(**init_kw).plot_topo(**plot_kw) @legacy(alt=".compute_psd().plot_topomap()") @verbose def plot_psd_topomap( self, bands=None, tmin=None, tmax=None, ch_type=None, *, proj=False, method="auto", normalize=False, agg_fun=None, dB=False, sensors=True, show_names=False, mask=None, mask_params=None, contours=0, outlines="head", sphere=None, image_interp=_INTERPOLATION_DEFAULT, extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, size=1, cmap=None, vlim=(None, None), cnorm=None, colorbar=True, cbar_fmt="auto", units=None, axes=None, show=True, n_jobs=None, verbose=None, **method_kw, ): """Plot scalp topography of PSD for chosen frequency bands. Parameters ---------- %(bands_psd_topo)s %(tmin_tmax_psd)s %(ch_type_topomap_psd)s %(proj_psd)s %(method_plot_psd_auto)s %(normalize_psd_topo)s %(agg_fun_psd_topo)s %(dB_plot_topomap)s %(sensors_topomap)s %(show_names_topomap)s %(mask_evoked_topomap)s %(mask_params_topomap)s %(contours_topomap)s %(outlines_topomap)s %(sphere_topomap_auto)s %(image_interp_topomap)s %(extrapolate_topomap)s %(border_topomap)s %(res_topomap)s %(size_topomap)s %(cmap_topomap)s %(vlim_plot_topomap_psd)s %(cnorm)s .. versionadded:: 1.2 %(colorbar_topomap)s %(cbar_fmt_topomap_psd)s %(units_topomap)s %(axes_spectrum_plot_topomap)s %(show)s %(n_jobs)s %(verbose)s %(method_kw_psd)s Returns ------- fig : instance of Figure Figure showing one scalp topography per frequency band. """ init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot_topomap) return self.compute_psd(**init_kw).plot_topomap(**plot_kw) def _set_legacy_nfft_default(self, tmin, tmax, method, method_kw): """Update method_kw with legacy n_fft default for plot_psd[_topo](). This method returns ``None`` and has a side effect of (maybe) updating the ``method_kw`` dict. 
""" if method == "welch" and method_kw.get("n_fft") is None: tm = _time_mask(self.times, tmin, tmax, sfreq=self.info["sfreq"]) method_kw["n_fft"] = min(np.sum(tm), 2048) class BaseSpectrum(ContainsMixin, UpdateChannelsMixin): """Base class for Spectrum and EpochsSpectrum.""" def __init__( self, inst, method, fmin, fmax, tmin, tmax, picks, exclude, proj, remove_dc, *, n_jobs, verbose=None, **method_kw, ): # arg checking self._sfreq = inst.info["sfreq"] if np.isfinite(fmax) and (fmax > self.sfreq / 2): raise ValueError( f"Requested fmax ({fmax} Hz) must not exceed ½ the sampling " f'frequency of the data ({0.5 * inst.info["sfreq"]} Hz).' ) # method self._inst_type = type(inst) method = _validate_method(method, _get_instance_type_string(self)) psd_funcs = dict(welch=psd_array_welch, multitaper=psd_array_multitaper) # triage method and kwargs. partial() doesn't check validity of kwargs, # so we do it manually to save compute time if any are invalid. psd_funcs = dict(welch=psd_array_welch, multitaper=psd_array_multitaper) _check_method_kwargs(psd_funcs[method], method_kw, msg=f'PSD method "{method}"') self._psd_func = partial(psd_funcs[method], remove_dc=remove_dc, **method_kw) # apply proj if desired if proj: inst = inst.copy().apply_proj() self.inst = inst # prep times and picks self._time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq) self._picks = _picks_to_idx( inst.info, picks, "data", exclude, with_ref_meg=False ) # add the info object. 
bads and non-data channels were dropped by # _picks_to_idx() so we update the info accordingly: self.info = pick_info(inst.info, sel=self._picks, copy=True) # assign some attributes self.preload = True # needed for __getitem__, never False self._method = method # self._dims may also get updated by child classes self._dims = ( "channel", "freq", ) if method_kw.get("average", "") in (None, False): self._dims += ("segment",) if self._returns_complex_tapers(**method_kw): self._dims = self._dims[:-1] + ("taper",) + self._dims[-1:] # record data type (for repr and html_repr) self._data_type = ( "Fourier Coefficients" if method_kw.get("output") == "complex" else "Power Spectrum" ) # set nave (child constructor overrides this for Evoked input) self._nave = None def __eq__(self, other): """Test equivalence of two Spectrum instances.""" return object_diff(vars(self), vars(other)) == "" def __getstate__(self): """Prepare object for serialization.""" inst_type_str = _get_instance_type_string(self) out = dict( method=self.method, data=self._data, sfreq=self.sfreq, dims=self._dims, freqs=self.freqs, inst_type_str=inst_type_str, data_type=self._data_type, info=self.info, nave=self.nave, weights=self.weights, ) return out def __setstate__(self, state): """Unpack from serialized format.""" from ..epochs import Epochs from ..evoked import Evoked from ..io import Raw self._method = state["method"] self._data = state["data"] self._freqs = state["freqs"] self._dims = state["dims"] self._sfreq = state["sfreq"] self.info = Info(**state["info"]) self._data_type = state["data_type"] self._nave = state.get("nave") # objs saved before #11282 won't have `nave` self._weights = state.get("weights") # objs saved before #12747 won't have self.preload = True # instance type inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Array=np.ndarray) self._inst_type = inst_types[state["inst_type_str"]] def __repr__(self): """Build string representation of the Spectrum object.""" inst_type_str = 
_get_instance_type_string(self) # shape & dimension names dims = " × ".join( [f"{dim[0]} {dim[1]}s" for dim in zip(self.shape, self._dims)] ) freq_range = f"{self.freqs[0]:0.1f}-{self.freqs[-1]:0.1f} Hz" return ( f"<{self._data_type} (from {inst_type_str}, " f"{self.method} method) | {dims}, {freq_range}>" ) @repr_html def _repr_html_(self, caption=None): """Build HTML representation of the Spectrum object.""" inst_type_str = _get_instance_type_string(self) units = [f"{ch_type}: {unit}" for ch_type, unit in self.units().items()] t = _get_html_template("repr", "spectrum.html.jinja") t = t.render(spectrum=self, inst_type=inst_type_str, units=units) return t def _check_values(self): """Check PSD results for correct shape and bad values.""" assert len(self._dims) == self._data.ndim, (self._dims, self._data.ndim) assert self._data.shape == self._shape # TODO: should this be more fine-grained (report "chan X in epoch Y")? ch_dim = self._dims.index("channel") dims = list(range(self._data.ndim)) dims.pop(ch_dim) # take min() across all but the channel axis # (if the abs becomes memory intensive we could iterate over channels) use_data = self._data if use_data.dtype.kind == "c": use_data = np.abs(use_data) bad_value = use_data.min(axis=tuple(dims)) == 0 bad_value &= ~np.isin(self.ch_names, self.info["bads"]) if bad_value.any(): chs = np.array(self.ch_names)[bad_value].tolist() s = _pl(bad_value.sum()) warn(f'Zero value in spectrum for channel{s} {", ".join(chs)}', UserWarning) def _returns_complex_tapers(self, **method_kw): return self.method == "multitaper" and method_kw.get("output") == "complex" def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose): # make the spectra result = self._psd_func( data, self.sfreq, fmin=fmin, fmax=fmax, n_jobs=n_jobs, verbose=verbose ) # assign ._data (handling unaggregated multitaper output) if self._returns_complex_tapers(**method_kw): fourier_coefs, freqs, weights = result self._data = fourier_coefs self._weights = 
weights else: psds, freqs = result self._data = psds self._weights = None # assign properties (._data already assigned above) self._freqs = freqs # this is *expected* shape, it gets asserted later in _check_values() # (and then deleted afterwards) self._shape = (len(self.ch_names), len(self.freqs)) # append n_welch_segments (use "" as .get() default since None considered valid) if method_kw.get("average", "") in (None, False): n_welch_segments = _compute_n_welch_segments(data.shape[-1], method_kw) self._shape += (n_welch_segments,) # insert n_tapers if self._returns_complex_tapers(**method_kw): self._shape = self._shape[:-1] + (self._weights.size,) + self._shape[-1:] # we don't need these anymore, and they make save/load harder del self._picks del self._psd_func del self._time_mask @property def _detrend_picks(self): """Provide compatibility with __iter__.""" return list() @property def ch_names(self): return self.info["ch_names"] @property def data(self): return self._data @property def freqs(self): return self._freqs @property def method(self): return self._method @property def nave(self): return self._nave @property def weights(self): return self._weights @property def sfreq(self): return self._sfreq @property def shape(self): return self._data.shape def copy(self): """Return copy of the Spectrum instance. Returns ------- spectrum : instance of Spectrum A copy of the object. """ return deepcopy(self) @fill_doc def get_data( self, picks=None, exclude="bads", fmin=0, fmax=np.inf, return_freqs=False ): """Get spectrum data in NumPy array format. Parameters ---------- %(picks_good_data_noref)s %(exclude_spectrum_get_data)s %(fmin_fmax_psd)s return_freqs : bool Whether to return the frequency bin values for the requested frequency range. Default is ``False``. Returns ------- data : array The requested data in a NumPy array. freqs : array The frequency values for the requested range. Only returned if ``return_freqs`` is ``True``. 
""" picks = _picks_to_idx( self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False ) fmin_idx = np.searchsorted(self.freqs, fmin) fmax_idx = np.searchsorted(self.freqs, fmax, side="right") freq_picks = np.arange(fmin_idx, fmax_idx) freq_axis = self._dims.index("freq") chan_axis = self._dims.index("channel") # normally there's a risk of np.take reducing array dimension if there # were only one channel or frequency selected, but `_picks_to_idx` # always returns an array of picks, and np.arange always returns an # array of freq bin indices, so we're safe; the result will always be # 2D. data = self._data.take(picks, chan_axis).take(freq_picks, freq_axis) if return_freqs: freqs = self._freqs[fmin_idx:fmax_idx] return (data, freqs) return data @fill_doc def plot( self, *, picks=None, average=False, dB=True, amplitude=False, xscale="linear", ci="sd", ci_alpha=0.3, color="black", alpha=None, spatial_colors=True, sphere=None, exclude=(), axes=None, show=True, ): """%(plot_psd_doc)s. Parameters ---------- %(picks_all_data_noref)s .. versionchanged:: 1.5 In version 1.5, the default behavior changed so that all :term:`data channels` (not just "good" data channels) are shown by default. average : bool Whether to average across channels before plotting. If ``True``, interactive plotting of scalp topography is disabled, and parameters ``ci`` and ``ci_alpha`` control the style of the confidence band around the mean. Default is ``False``. %(dB_spectrum_plot)s amplitude : bool Whether to plot an amplitude spectrum (``True``) or power spectrum (``False``). .. versionchanged:: 1.8 In version 1.8, the default changed to ``amplitude=False``. %(xscale_plot_psd)s ci : float | 'sd' | 'range' | None Type of confidence band drawn around the mean when ``average=True``. If ``'sd'`` the band spans ±1 standard deviation across channels. If ``'range'`` the band spans the range across channels at each frequency. 
If a :class:`float`, it indicates the (bootstrapped) confidence interval to display, and must satisfy ``0 < ci <= 100``. If ``None``, no band is drawn. Default is ``sd``. ci_alpha : float Opacity of the confidence band. Must satisfy ``0 <= ci_alpha <= 1``. Default is 0.3. %(color_plot_psd)s alpha : float | None Opacity of the spectrum line(s). If :class:`float`, must satisfy ``0 <= alpha <= 1``. If ``None``, opacity will be ``1`` when ``average=True`` and ``0.1`` when ``average=False``. Default is ``None``. %(spatial_colors_psd)s %(sphere_topomap_auto)s %(exclude_spectrum_plot)s .. versionchanged:: 1.5 In version 1.5, the default behavior changed from ``exclude='bads'`` to ``exclude=()``. %(axes_spectrum_plot_topomap)s %(show)s Returns ------- fig : instance of matplotlib.figure.Figure Figure with spectra plotted in separate subplots for each channel type. """ # Must nest this _mpl_figure import because of the BACKEND global # stuff from ..viz._mpl_figure import _line_figure, _split_picks_by_type # arg checking ci = _check_ci(ci) _check_option("xscale", xscale, ("log", "linear")) sphere = _check_sphere(sphere, self.info) # defaults scalings = _handle_default("scalings", None) titles = _handle_default("titles", None) units = _handle_default("units", None) _validate_type(amplitude, bool, "amplitude") estimate = "amplitude" if amplitude else "power" logger.info(f"Plotting {estimate} spectral density ({dB=}).") # split picks by channel type picks = _picks_to_idx( self.info, picks, "data", exclude=exclude, with_ref_meg=False ) (picks_list, units_list, scalings_list, titles_list) = _split_picks_by_type( self, picks, units, scalings, titles ) # prepare data (e.g. 
aggregate across dims, convert complex to power) psd_list = [ self._prepare_data_for_plot( self._data.take(_p, axis=self._dims.index("channel")) ) for _p in picks_list ] # initialize figure fig, axes = _line_figure(self, axes, picks=picks) # don't add ylabels & titles if figure has unexpected number of axes make_label = len(axes) == len(fig.axes) # Plot Frequency [Hz] xlabel only on the last axis xlabels_list = [False] * (len(axes) - 1) + [True] # plot _plot_psd( self, fig, self.freqs, psd_list, picks_list, titles_list, units_list, scalings_list, axes, make_label, color, area_mode=ci, area_alpha=ci_alpha, dB=dB, estimate=estimate, average=average, spatial_colors=spatial_colors, xscale=xscale, line_alpha=alpha, sphere=sphere, xlabels_list=xlabels_list, ) plt_show(show, fig) return fig @fill_doc def plot_topo( self, *, dB=True, layout=None, color="w", fig_facecolor="k", axis_facecolor="k", axes=None, block=False, show=True, ): """Plot power spectral density, separately for each channel. Parameters ---------- %(dB_spectrum_plot_topo)s %(layout_spectrum_plot_topo)s %(color_spectrum_plot_topo)s %(fig_facecolor)s %(axis_facecolor)s %(axes_spectrum_plot_topo)s %(block)s %(show)s Returns ------- fig : instance of matplotlib.figure.Figure Figure distributing one image per channel across sensor topography. """ if layout is None: layout = find_layout(self.info) psds, freqs = self.get_data(return_freqs=True) # prepare data (e.g. 
aggregate across dims, convert complex to power) psds = self._prepare_data_for_plot(psds) if dB: psds = 10 * np.log10(psds) y_label = "dB" else: y_label = "Power" show_func = partial( _plot_timeseries_unified, data=[psds], color=color, times=[freqs] ) click_func = partial(_plot_timeseries, data=[psds], color=color, times=[freqs]) picks = _pick_data_channels(self.info) info = pick_info(self.info, picks) fig = _plot_topo( info, times=freqs, show_func=show_func, click_func=click_func, layout=layout, axis_facecolor=axis_facecolor, fig_facecolor=fig_facecolor, x_label="Frequency (Hz)", unified=True, y_label=y_label, axes=axes, ) plt_show(show, block=block) return fig @fill_doc def plot_topomap( self, bands=None, ch_type=None, *, normalize=False, agg_fun=None, dB=False, sensors=True, show_names=False, mask=None, mask_params=None, contours=6, outlines="head", sphere=None, image_interp=_INTERPOLATION_DEFAULT, extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, size=1, cmap=None, vlim=(None, None), cnorm=None, colorbar=True, cbar_fmt="auto", units=None, axes=None, show=True, ): """Plot scalp topography of PSD for chosen frequency bands. Parameters ---------- %(bands_psd_topo)s %(ch_type_topomap_psd)s %(normalize_psd_topo)s %(agg_fun_psd_topo)s %(dB_plot_topomap)s %(sensors_topomap)s %(show_names_topomap)s %(mask_evoked_topomap)s %(mask_params_topomap)s %(contours_topomap)s %(outlines_topomap)s %(sphere_topomap_auto)s %(image_interp_topomap)s %(extrapolate_topomap)s %(border_topomap)s %(res_topomap)s %(size_topomap)s %(cmap_topomap)s %(vlim_plot_topomap_psd)s %(cnorm)s %(colorbar_topomap)s %(cbar_fmt_topomap_psd)s %(units_topomap)s %(axes_spectrum_plot_topomap)s %(show)s Returns ------- fig : instance of Figure Figure showing one scalp topography per frequency band. 
""" ch_type = _get_plot_ch_type(self, ch_type) if units is None: units = _handle_default("units", None) unit = units[ch_type] if hasattr(units, "keys") else units scalings = _handle_default("scalings", None) scaling = scalings[ch_type] ( picks, pos, merge_channels, names, ch_type, sphere, clip_origin, ) = _prepare_topomap_plot(self, ch_type, sphere=sphere) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) psds, freqs = self.get_data(picks=picks, return_freqs=True) # prepare data (e.g. aggregate across dims, convert complex to power) psds = self._prepare_data_for_plot(psds) psds *= scaling**2 if merge_channels: psds, names = _merge_ch_data(psds, ch_type, names, method="mean") names = _prepare_sensor_names(names, show_names) return plot_psds_topomap( psds=psds, freqs=freqs, pos=pos, bands=bands, ch_type=ch_type, normalize=normalize, agg_fun=agg_fun, dB=dB, sensors=sensors, names=names, mask=mask, mask_params=mask_params, contours=contours, outlines=outlines, sphere=sphere, image_interp=image_interp, extrapolate=extrapolate, border=border, res=res, size=size, cmap=cmap, vlim=vlim, cnorm=cnorm, colorbar=colorbar, cbar_fmt=cbar_fmt, unit=unit, axes=axes, show=show, ) def _prepare_data_for_plot(self, data): # handle unaggregated Welch if "segment" in self._dims: logger.info("Aggregating Welch estimates (median) before plotting...") data = np.nanmedian(data, axis=self._dims.index("segment")) # handle unaggregated multitaper (also handles complex -> power) elif "taper" in self._dims: logger.info("Aggregating multitaper estimates before plotting...") data = _psd_from_mt(data, self.weights) # handle complex data (should only be Welch remaining) if np.iscomplexobj(data): data = (data * data.conj()).real # Scaling may be slightly off # handle epochs if "epoch" in self._dims: # XXX TODO FIXME decide how to properly aggregate across repeated # measures (epochs) and non-repeated but correlated measures # (channels) when calculating stddev or a CI. 
For across-channel # aggregation, doi:10.1007/s10162-012-0321-8 used hotellings T**2 # with a correction factor that estimated data rank using monte # carlo simulations; seems like we could use our own data rank # estimation methods to similar effect. Their exact approach used # complex spectra though, here we've already converted to power; # not sure if that makes an important difference? Anyway that # aggregation would need to happen in the _plot_psd function # though, not here... for now we just average like we always did. # only log message if averaging will actually have an effect if data.shape[0] > 1: logger.info("Averaging across epochs before plotting...") # epoch axis should always be the first axis data = data.mean(axis=0) return data @verbose def save(self, fname, *, overwrite=False, verbose=None): """Save spectrum data to disk (in HDF5 format). Parameters ---------- fname : path-like Path of file to save to. %(overwrite)s %(verbose)s See Also -------- mne.time_frequency.read_spectrum """ _, write_hdf5 = _import_h5io_funcs() check_fname(fname, "spectrum", (".h5", ".hdf5")) fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) out = self.__getstate__() write_hdf5(fname, out, overwrite=overwrite, title="mnepython") @verbose def to_data_frame( self, picks=None, index=None, copy=True, long_format=False, *, verbose=None ): """Export data in tabular structure as a pandas DataFrame. Channels are converted to columns in the DataFrame. By default, an additional column "freq" is added, unless ``index='freq'`` (in which case frequency values form the DataFrame's index). Parameters ---------- %(picks_all)s index : str | list of str | None Kind of index to use for the DataFrame. If ``None``, a sequential integer index (:class:`pandas.RangeIndex`) will be used. If a :class:`str`, a :class:`pandas.Index` will be used (see Notes). If a list of two or more string values, a :class:`pandas.MultiIndex` will be used. Defaults to ``None``. 
%(copy_df)s %(long_format_df_spe)s %(verbose)s Returns ------- %(df_return)s Notes ----- Valid values for ``index`` depend on whether the Spectrum was created from continuous data (:class:`~mne.io.Raw`, :class:`~mne.Evoked`) or discontinuous data (:class:`~mne.Epochs`). For continuous data, only ``None`` or ``'freq'`` is supported. For discontinuous data, additional valid values are ``'epoch'`` and ``'condition'``, or a :class:`list` comprising some of the valid string values (e.g., ``['freq', 'epoch']``). """ # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa # triage for Epoch-derived or unaggregated spectra from_epo = _get_instance_type_string(self) == "Epochs" unagg_welch = "segment" in self._dims unagg_mt = "taper" in self._dims # arg checking valid_index_args = ["freq"] if from_epo: valid_index_args += ["epoch", "condition"] index = _check_pandas_index_arguments(index, valid_index_args) # get data picks = _picks_to_idx(self.info, picks, "all", exclude=()) data = self.get_data(picks) if copy: data = data.copy() # reshape if unagg_mt: data = np.moveaxis(data, self._dims.index("freq"), -2) if from_epo: n_epochs, n_picks, n_freqs = data.shape[:3] else: n_epochs, n_picks, n_freqs = (1,) + data.shape[:2] n_segs = data.shape[-1] if unagg_mt or unagg_welch else 1 data = np.moveaxis(data, self._dims.index("channel"), -1) # at this point, should be ([epoch], freq, [segment/taper], channel) data = data.reshape(n_epochs * n_freqs * n_segs, n_picks) # prepare extra columns / multiindex mindex = list() default_index = list() if from_epo: rev_event_id = {v: k for k, v in self.event_id.items()} _conds = [rev_event_id[k] for k in self.events[:, 2]] conditions = np.repeat(_conds, n_freqs * n_segs) epoch_nums = np.repeat(self.selection, n_freqs * n_segs) mindex.extend([("condition", conditions), ("epoch", epoch_nums)]) default_index.extend(["condition", "epoch"]) freqs = np.tile(np.repeat(self.freqs, n_segs), n_epochs) 
mindex.append(("freq", freqs)) default_index.append("freq") if unagg_mt or unagg_welch: name = "taper" if unagg_mt else "segment" seg_nums = np.tile(np.arange(n_segs), n_epochs * n_freqs) mindex.append((name, seg_nums)) default_index.append(name) # build DataFrame df = _build_data_frame( self, data, picks, long_format, mindex, index, default_index=default_index ) return df def units(self, latex=False): """Get the spectrum units for each channel type. Parameters ---------- latex : bool Whether to format the unit strings as LaTeX. Default is ``False``. Returns ------- units : dict Mapping from channel type to a string representation of the units for that channel type. """ units = _handle_default("si_units", None) return { ch_type: _format_units_psd(units[ch_type], power=True, latex=latex) for ch_type in sorted(self.get_channel_types(unique=True)) } @fill_doc class Spectrum(BaseSpectrum): """Data object for spectral representations of continuous data. .. warning:: The preferred means of creating Spectrum objects from continuous or averaged data is via the instance methods :meth:`mne.io.Raw.compute_psd` or :meth:`mne.Evoked.compute_psd`. Direct class instantiation is not supported. Parameters ---------- inst : instance of Raw or Evoked The data from which to compute the frequency spectrum. %(method_psd_auto)s ``'auto'`` (default) uses Welch's method for continuous data and multitaper for :class:`~mne.Evoked` data. %(fmin_fmax_psd)s %(tmin_tmax_psd)s %(picks_good_data_noref)s %(exclude_psd)s %(proj_psd)s %(remove_dc)s %(reject_by_annotation_psd)s %(n_jobs)s %(verbose)s %(method_kw_psd)s Attributes ---------- ch_names : list The channel names. freqs : array Frequencies at which the amplitude, power, or fourier coefficients have been computed. %(info_not_none)s method : ``'welch'``| ``'multitaper'`` The method used to compute the spectrum. nave : int | None The number of trials averaged together when generating the spectrum. 
``None`` indicates no averaging is known to have occurred. weights : array | None The weights for each taper. Only present if spectra computed with ``method='multitaper'`` and ``output='complex'``. .. versionadded:: 1.8 See Also -------- EpochsSpectrum SpectrumArray mne.io.Raw.compute_psd mne.Epochs.compute_psd mne.Evoked.compute_psd References ---------- .. footbibliography:: """ def __init__( self, inst, method, fmin, fmax, tmin, tmax, picks, exclude, proj, remove_dc, reject_by_annotation, *, n_jobs, verbose=None, **method_kw, ): from ..io import BaseRaw # triage reading from file if isinstance(inst, dict): self.__setstate__(inst) return # do the basic setup super().__init__( inst, method, fmin, fmax, tmin, tmax, picks, exclude, proj, remove_dc, n_jobs=n_jobs, verbose=verbose, **method_kw, ) # get just the data we want if isinstance(self.inst, BaseRaw): start, stop = np.where(self._time_mask)[0][[0, -1]] rba = "NaN" if reject_by_annotation else None data = self.inst.get_data( self._picks, start, stop + 1, reject_by_annotation=rba ) if np.any(np.isnan(data)) and method == "multitaper": raise NotImplementedError( 'Cannot use method="multitaper" when reject_by_annotation=True. ' 'Please use method="welch" instead.' ) else: # Evoked data = self.inst.data[self._picks][:, self._time_mask] # set nave self._nave = getattr(inst, "nave", None) # compute the spectra self._compute_spectra(data, fmin, fmax, n_jobs, method_kw, verbose) # check for correct shape and bad values self._check_values() del self._shape # calculated from self._data henceforth # save memory del self.inst def __getitem__(self, item): """Get Spectrum data. Parameters ---------- item : int | slice | array-like Indexing is similar to a :class:`NumPy array`; see Notes. 
Returns ------- %(getitem_spectrum_return)s Notes ----- Integer-, list-, and slice-based indexing is possible: - ``spectrum[0]`` gives all frequency bins in the first channel - ``spectrum[:3]`` gives all frequency bins in the first 3 channels - ``spectrum[[0, 2], 5]`` gives the value in the sixth frequency bin of the first and third channels - ``spectrum[(4, 7)]`` is the same as ``spectrum[4, 7]``. .. note:: Unlike :class:`~mne.io.Raw` objects (which returns a tuple of the requested data values and the corresponding times), accessing :class:`~mne.time_frequency.Spectrum` values via subscript does **not** return the corresponding frequency bin values. If you need them, use ``spectrum.freqs[freq_indices]`` or ``spectrum.get_data(..., return_freqs=True)``. """ from ..io import BaseRaw self._parse_get_set_params = partial(BaseRaw._parse_get_set_params, self) return BaseRaw._getitem(self, item, return_times=False) def _check_data_shape(data, info, freqs, dim_names, weights, is_epoched): if data.ndim != len(dim_names): raise ValueError( f"Expected data to have {len(dim_names)} dimensions, got {data.ndim}." ) allowed_dims = ["epoch", "channel", "freq", "segment", "taper"] if not is_epoched: allowed_dims.remove("epoch") # TODO maybe we should be nice and allow plural versions of each dimname? for dim in dim_names: _check_option("dim_names", dim, allowed_dims) if "channel" not in dim_names or "freq" not in dim_names: raise ValueError("Both 'channel' and 'freq' must be present in `dim_names`.") if list(dim_names).index("channel") != int(is_epoched): raise ValueError( f"'channel' must be the {'second' if is_epoched else 'first'} dimension of " "the data." ) want_n_chan = _pick_data_channels(info).size got_n_chan = data.shape[list(dim_names).index("channel")] if got_n_chan != want_n_chan: raise ValueError( f"The number of channels in `data` ({got_n_chan}) must match the number of " f"good data channels in `info` ({want_n_chan})." 
) # given we limit max array size and ensure channel & freq dims present, only one of # taper or segment can be present if "taper" in dim_names: if dim_names[-2] != "taper": # _psd_from_mt assumes this (called when plotting) raise ValueError( "'taper' must be the second to last dimension of the data." ) # expect weights for each taper actual = None if weights is None else weights.size expected = data.shape[list(dim_names).index("taper")] if actual != expected: raise ValueError( f"Expected size of `weights` to be {expected} to match 'n_tapers' in " f"`data`, got {actual}." ) elif "segment" in dim_names and dim_names[-1] != "segment": raise ValueError("'segment' must be the last dimension of the data.") # freq being in wrong position ruled out by above checks want_n_freq = freqs.size got_n_freq = data.shape[list(dim_names).index("freq")] if got_n_freq != want_n_freq: raise ValueError( f"The number of frequencies in `data` ({got_n_freq}) must match the number " f"of elements in `freqs` ({want_n_freq})." ) @fill_doc class SpectrumArray(Spectrum): """Data object for precomputed spectral data (in NumPy array format). Parameters ---------- data : ndarray, shape (n_channels, [n_tapers], n_freqs, [n_segments]) The spectra for each channel. %(info_not_none)s %(freqs_tfr_array)s dim_names : tuple of str The name of the dimensions in the data, in the order they occur. Must contain ``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g., multitaper algorithms) dimension. If including ``'taper'``, you should also pass a ``weights`` parameter. .. versionadded:: 1.8 weights : ndarray | None Weights for the ``'taper'`` dimension, if present (see ``dim_names``). .. versionadded:: 1.8 %(verbose)s See Also -------- mne.create_info mne.EvokedArray mne.io.RawArray EpochsSpectrumArray Notes ----- %(notes_spectrum_array)s .. 
@fill_doc
class EpochsSpectrum(BaseSpectrum, GetEpochsMixin):
    """Data object for spectral representations of epoched data.

    .. warning:: The preferred means of creating Spectrum objects from Epochs
                 is via the instance method :meth:`mne.Epochs.compute_psd`.
                 Direct class instantiation is not supported.

    Parameters
    ----------
    inst : instance of Epochs
        The data from which to compute the frequency spectrum.
    %(method_psd)s
    %(fmin_fmax_psd)s
    %(tmin_tmax_psd)s
    %(picks_good_data_noref)s
    %(exclude_psd)s
    %(proj_psd)s
    %(remove_dc)s
    %(n_jobs)s
    %(verbose)s
    %(method_kw_psd)s

    Attributes
    ----------
    ch_names : list
        The channel names.
    freqs : array
        Frequencies at which the amplitude, power, or Fourier coefficients
        have been computed.
    %(info_not_none)s
    method : ``'welch'``| ``'multitaper'``
        The method used to compute the spectrum.
    weights : array | None
        The weights for each taper. Only present if spectra computed with
        ``method='multitaper'`` and ``output='complex'``.

        .. versionadded:: 1.8

    See Also
    --------
    EpochsSpectrumArray
    Spectrum
    mne.Epochs.compute_psd

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        inst,
        method,
        fmin,
        fmax,
        tmin,
        tmax,
        picks,
        exclude,
        proj,
        remove_dc,
        *,
        n_jobs,
        verbose=None,
        **method_kw,
    ):
        # a dict in place of an instance means we are restoring a serialized
        # state (e.g., when reading from file): no spectra need computing
        if isinstance(inst, dict):
            self.__setstate__(inst)
            return

        # the base class does the basic setup
        super().__init__(
            inst,
            method,
            fmin,
            fmax,
            tmin,
            tmax,
            picks,
            exclude,
            proj,
            remove_dc,
            n_jobs=n_jobs,
            verbose=verbose,
            **method_kw,
        )

        # pull out the picked channels, then keep only the samples selected by
        # the time mask (the last axis of the epochs data)
        picked = self.inst._get_data(picks=self._picks, on_empty="raise")
        data = picked[:, :, self._time_mask]

        # compute the spectra, then account for the leading epochs axis
        self._compute_spectra(data, fmin, fmax, n_jobs, method_kw, verbose)
        self._dims = ("epoch",) + self._dims
        self._shape = (len(self.inst),) + self._shape

        # check for correct shape and bad values; the expected shape is only
        # needed for that check
        self._check_values()
        del self._shape

        epochs = self.inst
        # needed by to_data_frame()
        self.event_id = epochs.event_id.copy()
        self.events = epochs.events.copy()
        self.selection = epochs.selection.copy()
        # needed by __getitem__()
        self.drop_log = deepcopy(epochs.drop_log)
        self._metadata = epochs.metadata

        # drop the reference to the source instance to save memory
        del self.inst

    def __getitem__(self, item):
        """Subselect epochs from an EpochsSpectrum.

        Parameters
        ----------
        item : int | slice | array-like | str
            Access options are the same as for :class:`~mne.Epochs` objects,
            see the docstring of :meth:`mne.Epochs.__getitem__` for
            explanation.

        Returns
        -------
        %(getitem_epochspectrum_return)s
        """
        return super().__getitem__(item)

    def __getstate__(self):
        """Prepare object for serialization."""
        state = super().__getstate__()
        # add the epochs-specific fields on top of the base-class state
        state["metadata"] = self._metadata
        state["drop_log"] = self.drop_log
        state["event_id"] = self.event_id
        state["events"] = self.events
        state["selection"] = self.selection
        return state

    def __setstate__(self, state):
        """Unpack from serialized format."""
        super().__setstate__(state)
        # metadata is stored under a private attribute name
        self._metadata = state["metadata"]
        for key in ("drop_log", "event_id", "events", "selection"):
            setattr(self, key, state[key])

    def average(self, method="mean"):
        """Average the spectra across epochs.

        Parameters
        ----------
        method : 'mean' | 'median' | callable
            How to aggregate spectra across epochs. If callable, must take a
            :class:`NumPy array <numpy.ndarray>` of shape
            ``(n_epochs, n_channels, n_freqs)`` and return an array of shape
            ``(n_channels, n_freqs)``. Default is ``'mean'``.

        Returns
        -------
        spectrum : instance of Spectrum
            The aggregated spectrum object.
        """
        _validate_type(method, ("str", "callable"), "method")
        method = _make_combine_callable(
            method, axis=0, valid=("mean", "median"), keepdims=False
        )
        if not callable(method):
            raise ValueError(
                '"method" must be a valid string or callable, '
                f"got a {type(method).__name__} ({method})."
            )

        # averaging unaggregated spectral estimates is not supported
        unsupported = (
            (
                "segment",
                "Averaging individual Welch segments across epochs is not "
                "supported. Consider averaging the signals before computing "
                "the Welch spectrum estimates.",
            ),
            (
                "taper",
                "Averaging multitaper tapers across epochs is not supported. Consider "
                "averaging the signals before computing the complex spectrum.",
            ),
        )
        for dim, message in unsupported:
            if dim in self._dims:
                raise NotImplementedError(message)

        # start from the base-class serialized state, then collapse the epochs
        # axis and update data, dims, and data type accordingly
        state = super().__getstate__()
        state.update(
            nave=state["data"].shape[0],
            data=method(state["data"]),
            dims=state["dims"][1:],
            data_type=f'Averaged {state["data_type"]}',
        )

        # constructor arguments are placeholders when building from a state dict
        defaults = dict.fromkeys(
            (
                "method",
                "fmin",
                "fmax",
                "tmin",
                "tmax",
                "picks",
                "proj",
                "remove_dc",
                "reject_by_annotation",
                "n_jobs",
                "verbose",
            )
        )
        defaults["exclude"] = ()
        return Spectrum(state, **defaults)
def read_spectrum(fname):
    """Load a :class:`mne.time_frequency.Spectrum` object from disk.

    Parameters
    ----------
    fname : path-like
        Path to a spectrum file in HDF5 format, which should end with ``.h5``
        or ``.hdf5``.

    Returns
    -------
    spectrum : instance of Spectrum
        The loaded Spectrum object.

    See Also
    --------
    mne.time_frequency.Spectrum.save
    """
    h5_reader = _import_h5io_funcs()[0]
    _validate_type(fname, "path-like", "fname")
    fname = _check_fname(fname=fname, overwrite="read", must_exist=False)

    # the file holds the serialized state under the "mnepython" title
    state = h5_reader(fname, title="mnepython")

    # constructor arguments are placeholders when building from a state dict
    init_kwargs = dict.fromkeys(
        (
            "method",
            "fmin",
            "fmax",
            "tmin",
            "tmax",
            "picks",
            "proj",
            "remove_dc",
            "reject_by_annotation",
            "n_jobs",
            "verbose",
        )
    )
    init_kwargs["exclude"] = ()

    # the stored instance type decides which container class is rebuilt
    if state["inst_type_str"] == "Epochs":
        return EpochsSpectrum(state, **init_kwargs)
    return Spectrum(state, **init_kwargs)


def _check_ci(ci):
    """Validate ``ci``; numeric values are converted from percent to a fraction."""
    if ci == "std":  # be forgiving
        ci = "sd"
    if not _is_numeric(ci):
        _check_option("ci", ci, [None, "sd", "range"])
        return ci
    if not (0 < ci <= 100):
        raise ValueError(f"ci must satisfy 0 < ci <= 100, got {ci}")
    return ci / 100.0


def _compute_n_welch_segments(n_times, method_kw):
    """Compute the number of Welch segments expected for ``n_times`` samples."""
    # defaults come from psd_array_welch, overridden by user-specified values
    welch_params = signature(psd_array_welch).parameters
    welch_kw = {
        name: method_kw.get(name, welch_params[name].default)
        for name in ("n_fft", "n_per_seg", "n_overlap")
    }
    # sanity check values / replace `None`s with real numbers
    _, n_per_seg, n_overlap = _check_nfft(n_times, **welch_kw)
    # consecutive segments start (n_per_seg - n_overlap) samples apart
    return (n_times - n_overlap) // (n_per_seg - n_overlap)


def _validate_method(method, instance_type):
    """Convert 'auto' to a real method name, and validate."""
    if method == "auto":
        # instance types starting with "Raw" get Welch, all others multitaper
        if instance_type.startswith("Raw"):
            method = "welch"
        else:
            method = "multitaper"
    _check_option("method", method, ("welch", "multitaper"))
    return method