1161 lines
42 KiB
Python
1161 lines
42 KiB
Python
"""Functions to plot epochs data."""
|
|
|
|
# Authors: The MNE-Python contributors.
|
|
# License: BSD-3-Clause
|
|
# Copyright the MNE-Python contributors.
|
|
|
|
from collections import Counter
|
|
from copy import deepcopy
|
|
|
|
import numpy as np
|
|
from scipy.ndimage import gaussian_filter1d
|
|
|
|
from .._fiff.meas_info import create_info
|
|
from .._fiff.pick import (
|
|
_DATA_CH_TYPES_SPLIT,
|
|
_VALID_CHANNEL_TYPES,
|
|
_picks_to_idx,
|
|
)
|
|
from ..defaults import _handle_default
|
|
from ..utils import _check_option, fill_doc, legacy, logger, verbose, warn
|
|
from ..utils.spectrum import _split_psd_kwargs
|
|
from .raw import _setup_channel_selections
|
|
from .utils import (
|
|
DraggableColorbar,
|
|
_check_cov,
|
|
_compute_scalings,
|
|
_get_channel_plotting_order,
|
|
_handle_decim,
|
|
_handle_precompute,
|
|
_make_combine_callable,
|
|
_make_event_color_dict,
|
|
_set_title_multiple_electrodes,
|
|
_set_window_title,
|
|
_setup_cmap,
|
|
_setup_vmin_vmax,
|
|
_validate_type,
|
|
plt_show,
|
|
)
|
|
|
|
|
|
@fill_doc
def plot_epochs_image(
    epochs,
    picks=None,
    sigma=0.0,
    vmin=None,
    vmax=None,
    colorbar=True,
    order=None,
    show=True,
    units=None,
    scalings=None,
    cmap=None,
    fig=None,
    axes=None,
    overlay_times=None,
    combine=None,
    group_by=None,
    evoked=True,
    ts_args=None,
    title=None,
    clear=False,
):
    """Plot Event Related Potential / Fields image.

    Parameters
    ----------
    epochs : instance of Epochs
        The epochs.
    %(picks_good_data)s
        ``picks`` interacts with ``group_by`` and ``combine`` to determine the
        number of figures generated; see Notes.
    sigma : float
        The standard deviation of a Gaussian smoothing window applied along
        the epochs axis of the image. If 0, no smoothing is applied.
        Defaults to 0.
    vmin : None | float | callable
        The min value in the image (and the ER[P/F]). The unit is µV for
        EEG channels, fT for magnetometers and fT/cm for gradiometers.
        If vmin is None and multiple plots are returned, the limit is
        equalized within channel types.
        Hint: to specify the lower limit of the data, use
        ``vmin=lambda data: data.min()``.
    vmax : None | float | callable
        The max value in the image (and the ER[P/F]). The unit is µV for
        EEG channels, fT for magnetometers and fT/cm for gradiometers.
        If vmax is None and multiple plots are returned, the limit is
        equalized within channel types.
    colorbar : bool
        Display or not a colorbar.
    order : None | array of int | callable
        If not ``None``, order is used to reorder the epochs along the y-axis
        of the image. If it is an array of :class:`int`, its length should
        match the number of good epochs. If it is a callable it should accept
        two positional parameters (``times`` and ``data``, where
        ``data.shape == (len(good_epochs), len(times))``) and return an
        :class:`array <numpy.ndarray>` of indices that will sort ``data`` along
        its first axis.
    show : bool
        Show figure if True.
    units : dict | None
        The units of the channel types used for axes labels. If None,
        defaults to ``units=dict(eeg='µV', grad='fT/cm', mag='fT')``.
    scalings : dict | None
        The scalings of the channel types to be applied for plotting.
        If None, defaults to ``scalings=dict(eeg=1e6, grad=1e13, mag=1e15,
        eog=1e6)``.
    cmap : None | colormap | (colormap, bool) | 'interactive'
        Colormap. If tuple, the first value indicates the colormap to use and
        the second value is a boolean defining interactivity. In interactive
        mode the colors are adjustable by clicking and dragging the colorbar
        with left and right mouse button. Left mouse button moves the scale up
        and down and right mouse button adjusts the range. Hitting space bar
        resets the scale. Up and down arrows can be used to change the
        colormap. If 'interactive', translates to ('RdBu_r', True).
        If None, "RdBu_r" is used, unless the data is all positive, in which
        case "Reds" is used.
    fig : Figure | None
        :class:`~matplotlib.figure.Figure` instance to draw the image to.
        Figure must contain the correct number of axes for drawing the epochs
        image, the evoked response, and a colorbar (depending on values of
        ``evoked`` and ``colorbar``). If ``None`` a new figure is created.
        Defaults to ``None``.
    axes : list of Axes | dict of list of Axes | None
        List of :class:`~matplotlib.axes.Axes` objects in which to draw the
        image, evoked response, and colorbar (in that order). Length of list
        must be 1, 2, or 3 (depending on values of ``colorbar`` and ``evoked``
        parameters). If a :class:`dict`, each entry must be a list of Axes
        objects with the same constraints as above. If both ``axes`` and
        ``group_by`` are dicts, their keys must match. Providing non-``None``
        values for both ``fig`` and ``axes`` results in an error. Defaults to
        ``None``.
    overlay_times : array_like, shape (n_epochs,) | None
        Times (in seconds) at which to draw a line on the corresponding row of
        the image (e.g., a reaction time associated with each epoch). Note that
        ``overlay_times`` should be ordered to correspond with the
        :class:`~mne.Epochs` object (i.e., ``overlay_times[0]`` corresponds to
        ``epochs[0]``, etc).
    %(combine_plot_epochs_image)s
    group_by : None | dict
        Specifies which channels are aggregated into a single figure, with
        aggregation method determined by the ``combine`` parameter. If not
        ``None``, one :class:`~matplotlib.figure.Figure` is made per dict
        entry; the dict key will be used as the figure title and the dict
        values must be lists of picks (either channel names or integer indices
        of ``epochs.ch_names``). For example::

            group_by=dict(Left_ROI=[1, 2, 3, 4], Right_ROI=[5, 6, 7, 8])

        Note that within a dict entry all channels must have the same type.
        ``group_by`` interacts with ``picks`` and ``combine`` to determine the
        number of figures generated; see Notes. Defaults to ``None``.
    evoked : bool
        Draw the ER[P/F] below the image or not.
    ts_args : None | dict
        Arguments passed to a call to `~mne.viz.plot_compare_evokeds` to style
        the evoked plot below the image. Defaults to an empty dictionary,
        meaning `~mne.viz.plot_compare_evokeds` will be called with default
        parameters. The passed dictionary is copied and never modified.
    title : None | str
        If :class:`str`, will be plotted as figure title. Otherwise, the
        title will indicate channel(s) or channel type being plotted. Defaults
        to ``None``.
    clear : bool
        Whether to clear the axes before plotting (if ``fig`` or ``axes`` are
        provided). Defaults to ``False``.

    Returns
    -------
    figs : list of Figure
        One figure per channel, channel type, or group, depending on values of
        ``picks``, ``group_by``, and ``combine``. See Notes.

    Notes
    -----
    You can control how channels are aggregated into one figure or plotted in
    separate figures through a combination of the ``picks``, ``group_by``, and
    ``combine`` parameters. If ``group_by`` is a :class:`dict`, the result is
    one :class:`~matplotlib.figure.Figure` per dictionary key (for any valid
    values of ``picks`` and ``combine``). If ``group_by`` is ``None``, the
    number and content of the figures generated depends on the values of
    ``picks`` and ``combine``, as summarized in this table:

    .. cssclass:: table-bordered
    .. rst-class:: midvalign

    +----------+----------------------------+------------+-------------------+
    | group_by | picks                      | combine    | result            |
    +==========+============================+============+===================+
    |          | None, int, list of int,    | None,      |                   |
    | dict     | ch_name, list of ch_names, | string, or | 1 figure per      |
    |          | ch_type, list of ch_types  | callable   | dict key          |
    +----------+----------------------------+------------+-------------------+
    |          | None,                      | None,      |                   |
    |          | ch_type,                   | string, or | 1 figure per      |
    |          | list of ch_types           | callable   | ch_type           |
    | None     +----------------------------+------------+-------------------+
    |          | int,                       | None       | 1 figure per pick |
    |          | ch_name,                   +------------+-------------------+
    |          | list of int,               | string or  | 1 figure          |
    |          | list of ch_names           | callable   |                   |
    +----------+----------------------------+------------+-------------------+
    """
    from ..epochs import EpochsArray

    _validate_type(group_by, (dict, None), "group_by")

    units = _handle_default("units", units)
    scalings = _handle_default("scalings", scalings)
    if set(units) != set(scalings):
        raise ValueError("Scalings and units must have the same keys.")

    # is picks a channel type (or None)?
    picks, picked_types = _picks_to_idx(epochs.info, picks, return_kind=True)
    ch_types = epochs.info.get_channel_types(picks)

    # `combine` defaults to 'gfp' unless picks are specific channels and
    # there was no group_by passed
    combine_given = combine is not None
    if combine is None and (group_by is not None or picked_types):
        combine = "gfp"
    # convert `combine` into callable (if None or str)
    combine_func = _make_combine_callable(combine)

    # handle ts_args (params for the evoked time series); work on a copy so
    # that the caller's dict is never modified by the assignments below
    ts_args = dict() if ts_args is None else dict(ts_args)
    manual_ylims = "ylim" in ts_args
    if combine is not None:
        ts_args["show_sensors"] = False
    vlines = [0] if (epochs.times[0] < 0 < epochs.times[-1]) else []
    ts_defaults = dict(
        colors={"cond": "k"},
        title="",
        show=False,
        truncate_yaxis=False,
        truncate_xaxis=False,
        vlines=vlines,
        legend=False,
    )
    ts_defaults.update(**ts_args)
    ts_args = ts_defaults.copy()

    # construct a group_by dict if one wasn't supplied
    if group_by is None:
        if picked_types:
            # one fig per ch_type
            group_by = {
                ch_type: picks[np.array(ch_types) == ch_type]
                for ch_type in set(ch_types)
                if ch_type in _DATA_CH_TYPES_SPLIT + ("ref_meg",)
            }
        elif combine is None:
            # one fig per pick
            group_by = {epochs.ch_names[pick]: [pick] for pick in picks}
        else:
            # one fig to rule them all
            ch_names = np.array(epochs.ch_names)[picks].tolist()
            key = _set_title_multiple_electrodes(None, combine, ch_names)
            group_by = {key: picks}
    else:
        group_by = deepcopy(group_by)
    # check for heterogeneous sensor type combinations / "combining" 1 channel
    for this_group, these_picks in group_by.items():
        this_ch_type = np.array(ch_types)[np.isin(picks, these_picks)]
        if len(set(this_ch_type)) > 1:
            types = ", ".join(set(this_ch_type))
            raise ValueError(
                f'Cannot combine sensors of different types; "{this_group}" contains '
                f"types {types}."
            )
        # now we know they're all the same type...
        group_by[this_group] = dict(
            picks=these_picks, ch_type=this_ch_type[0], title=title
        )

        # are they trying to combine a single channel?
        if len(these_picks) < 2 and combine_given:
            warn(
                f'Only one channel in group "{this_group}"; cannot combine by method '
                f'"{combine}".'
            )

    # check for compatible `fig` / `axes`; instantiate figs if needed; add
    # fig(s) and axes into group_by
    needs_colorbar = colorbar and (axes is not None or fig is not None)
    group_by = _validate_fig_and_axes(
        fig, axes, group_by, evoked, colorbar=needs_colorbar, clear=clear
    )
    del fig, axes, needs_colorbar, clear

    # prepare images in advance to get consistent vmin/vmax.
    # At the same time, create a subsetted epochs object for each group
    data = epochs._get_data(on_empty="raise")
    vmin_vmax = {ch_type: dict(images=list(), norm=list()) for ch_type in set(ch_types)}
    for this_group, this_group_dict in group_by.items():
        these_picks = this_group_dict["picks"]
        this_ch_type = this_group_dict["ch_type"]
        this_ch_info = [epochs.info["chs"][n] for n in these_picks]
        these_ch_names = np.array(epochs.info["ch_names"])[these_picks]
        this_data = data[:, these_picks]
        # create subsetted epochs object
        this_info = create_info(
            sfreq=epochs.info["sfreq"],
            ch_names=list(these_ch_names),
            ch_types=[this_ch_type] * len(these_picks),
        )
        with this_info._unlock():
            this_info["chs"] = this_ch_info
        this_epochs = EpochsArray(this_data, this_info, tmin=epochs.times[0])
        # apply scalings (only to image, not epochs object), combine channels
        this_image = combine_func(this_data * scalings[this_ch_type])
        # handle `order`. NB: this can potentially yield different orderings
        # in each figure, so the reordered overlay times are kept per group
        this_image, this_overlay_times = _order_epochs(
            this_image, epochs.times, order, overlay_times
        )
        this_norm = np.all(this_image > 0)
        # apply smoothing
        if sigma > 0.0:
            this_image = gaussian_filter1d(
                this_image, sigma=sigma, axis=0, mode="nearest"
            )
        # update the group_by and vmin_vmax dicts
        group_by[this_group].update(
            image=this_image,
            epochs=this_epochs,
            norm=this_norm,
            overlay_times=this_overlay_times,
        )
        vmin_vmax[this_ch_type]["images"].append(this_image)
        vmin_vmax[this_ch_type]["norm"].append(this_norm)

    # compute overall vmin/vmax for images
    for this_type, this_vmin_vmax_dict in vmin_vmax.items():
        image_list = this_vmin_vmax_dict["images"]
        if not image_list:
            # a picked channel type that is in no group has nothing to stack
            # (np.stack raises on an empty list) and is never looked up below
            continue
        image_stack = np.stack(image_list)
        norm = all(this_vmin_vmax_dict["norm"])
        vmin_vmax[this_type] = _setup_vmin_vmax(image_stack, vmin, vmax, norm)
    del vmin, vmax

    # prepare to plot
    auto_ylims = {ch_type: [0.0, 0.0] for ch_type in set(ch_types)}

    # plot
    for this_group, this_group_dict in group_by.items():
        this_ch_type = this_group_dict["ch_type"]
        this_axes_dict = this_group_dict["axes"]
        this_vmin, this_vmax = vmin_vmax[this_ch_type]

        # plot title
        this_title = this_group_dict["title"]
        if this_title is None:
            this_title = _handle_default("titles").get(this_group, this_group)
            if isinstance(combine, str) and len(this_title):
                _comb = combine.upper() if combine == "gfp" else combine
                _comb = "std. dev." if _comb == "std" else _comb
                this_title += f" ({_comb})"

        # plot the image
        this_fig = _plot_epochs_image(
            this_group_dict["image"],
            epochs=this_group_dict["epochs"],
            picks=picks,
            colorbar=colorbar,
            vmin=this_vmin,
            vmax=this_vmax,
            cmap=cmap,
            style_axes=True,
            norm=this_group_dict["norm"],
            unit=units[this_ch_type],
            ax=this_axes_dict,
            show=False,
            title=this_title,
            combine=combine,
            combine_given=combine_given,
            overlay_times=this_group_dict["overlay_times"],
            evoked=evoked,
            ts_args=ts_args,
        )
        group_by[this_group].update(fig=this_fig)

        # detect ylims across figures
        if evoked and not manual_ylims:
            # ensure get_ylim works properly
            this_axes_dict["evoked"].figure.canvas.draw_idle()
            this_bot, this_top = this_axes_dict["evoked"].get_ylim()
            this_min = min(this_bot, this_top)
            this_max = max(this_bot, this_top)
            # accumulate per channel type of *this* group
            curr_min, curr_max = auto_ylims[this_ch_type]
            auto_ylims[this_ch_type] = [
                min(curr_min, this_min),
                max(curr_max, this_max),
            ]

    # equalize ylims across figures (does not adjust ticks)
    if evoked:
        for this_group_dict in group_by.values():
            ax = this_group_dict["axes"]["evoked"]
            ch_type = this_group_dict["ch_type"]
            if not manual_ylims:
                args = auto_ylims[ch_type]
                # only flip the limits when inversion was actually requested
                if ts_args.get("invert_y", False):
                    args = args[::-1]
                ax.set_ylim(*args)
    plt_show(show)
    # impose deterministic order of returned objects
    return_order = np.array(sorted(group_by))
    are_ch_types = np.isin(return_order, _VALID_CHANNEL_TYPES)
    if any(are_ch_types):
        return_order = np.concatenate(
            (return_order[are_ch_types], return_order[~are_ch_types])
        )
    return [group_by[group]["fig"] for group in return_order]
|
|
|
|
|
|
def _validate_fig_and_axes(fig, axes, group_by, evoked, colorbar, clear=False):
    """Reconcile user-supplied ``fig`` / ``axes`` with the groups to plot.

    Whatever form ``fig`` / ``axes`` arrive in (nothing, a Figure, a single
    Axes, an ndarray or list of Axes, or a dict of lists of Axes), the result
    is that every entry of ``group_by`` gains a ``"fig"`` key and an
    ``"axes"`` key, the latter being a dict keyed by the names ``"image"``,
    ``"evoked"`` and ``"colorbar"`` (only those that are needed).

    Parameters
    ----------
    fig : Figure | None
        Existing figure to draw into; only valid with a single group.
    axes : Axes | ndarray | list | dict | None
        Existing axes to draw into.
    group_by : dict
        Mapping of group name to a per-group dict; updated in place.
    evoked : bool
        Whether an axes for the evoked trace is needed.
    colorbar : bool
        Whether an axes for the colorbar is needed.
    clear : bool
        If True and ``fig`` is given, clear its axes (and drop one surplus
        trailing axes) before reuse.

    Returns
    -------
    group_by : dict
        The same dict that was passed in.

    Raises
    ------
    ValueError
        If both ``fig`` and ``axes`` are given, if the number of axes does
        not match what ``evoked`` / ``colorbar`` require, if more than one
        group is requested together with a figure or a list of axes, or if
        the keys of an ``axes`` dict differ from those of ``group_by``.
    """
    from matplotlib.pyplot import Axes, figure, subplot2grid

    n_axes = 1 + int(evoked) + int(colorbar)
    needed = np.array([True, evoked, colorbar], dtype=bool)
    ax_names = np.array(("image", "evoked", "colorbar"))[needed]
    prefix = f"Since evoked={evoked} and colorbar={colorbar}, "

    # fig and axes are mutually exclusive
    if not (fig is None or axes is None):
        raise ValueError(
            f'At least one of "fig" or "axes" must be None; got fig={fig}, axes={axes}.'
        )

    # neither given: build one new figure (with the needed axes) per group
    if fig is None and axes is None:
        grid = (3, 10)
        image_cols = 9 if colorbar else 10
        image_rows = 2 if evoked else 3
        axes = dict()
        for group_name in group_by:
            new_fig = figure(layout="constrained")
            _set_window_title(new_fig, group_name)
            subplot2grid(
                grid, (0, 0), colspan=image_cols, rowspan=image_rows, fig=new_fig
            )
            if evoked:
                subplot2grid(grid, (2, 0), colspan=image_cols, rowspan=1, fig=new_fig)
            if colorbar:
                subplot2grid(grid, (0, 9), colspan=1, rowspan=image_rows, fig=new_fig)
            axes[group_name] = new_fig.axes

    # a Figure instance was given
    if fig is not None:
        # When re-plotting into a figure from an earlier call, tolerate one
        # extra axes (the sensor inset may or may not be present).
        n_found = len(fig.axes)
        if n_found not in (n_axes, n_axes + 1):
            raise ValueError(
                f'{prefix}"fig" must contain {n_axes} axes, got {n_found}.'
            )
        if len(group_by) != 1:
            group_names = ", ".join(group_by)
            raise ValueError(
                'When "fig" is not None, "group_by" can only '
                f"have one group (got {len(group_by)}: {group_names})."
            )
        key = next(iter(group_by))
        if clear:  # needed when drawing into a previously used figure
            for old_ax in fig.axes:
                old_ax.clear()
            if len(fig.axes) > n_axes:  # drop the sensor inset
                fig.axes[-1].remove()
        _set_window_title(fig, key)
        axes = {key: fig.axes}

    # a lone Axes is accepted and treated as a one-element list
    if isinstance(axes, Axes):
        axes = [axes]

    # an ndarray of Axes is flattened to a list
    if isinstance(axes, np.ndarray):
        axes = axes.ravel().tolist()

    # a list of Axes only makes sense for a single group; wrap it in a dict
    if isinstance(axes, list):
        if len(axes) != n_axes:
            raise ValueError(
                f'{prefix}"axes" must be length {n_axes}, got {len(axes)}.'
            )
        if len(group_by) != 1:
            group_names = ", ".join(group_by)
            raise ValueError(
                "When axes is a list, can only plot one group "
                f"(got {len(group_by)} groups: {group_names})."
            )
        axes = {next(iter(group_by)): axes}

    # a dict of lists of Axes becomes a dict of name -> Axes dicts
    if isinstance(axes, dict):
        # A user-supplied axes dict without a user-supplied group_by dict is
        # ruled out by the public docstring; if it happens anyway the key
        # comparison below may fail because the group names were generated
        # internally.
        if set(axes) != set(group_by):
            raise ValueError(
                f'If "axes" is a dict its keys ({list(axes)}) must match the keys in '
                f'"group_by" ({list(group_by)}).'
            )
        for this_group, this_axes_list in axes.items():
            if len(this_axes_list) != n_axes:
                raise ValueError(
                    f'{prefix}each value in "axes" must be a list of {n_axes} axes, got'
                    f" {len(this_axes_list)}."
                )
            # NB: assumes all axes of one list live in the same figure
            group_by[this_group]["fig"] = this_axes_list[0].get_figure()
            group_by[this_group]["axes"] = dict(zip(ax_names, this_axes_list))
    return group_by
|
|
|
|
|
|
def _order_epochs(data, times, order=None, overlay_times=None):
    """Reorder the rows of a 2D epochs image (helper for plot_epochs_image).

    Parameters
    ----------
    data : array
        Image whose first axis indexes epochs.
    times : array
        Time points; only the first and last entries are used, to check the
        range of ``overlay_times``.
    order : None | array-like of int | callable
        Row indices to apply, or a callable ``order(times, data)`` returning
        such indices. ``None`` leaves the rows untouched.
    overlay_times : None | array-like
        One value per epoch; reordered together with ``data``.

    Returns
    -------
    data : array
        The (possibly reordered) image.
    overlay_times : array | None
        The (possibly reordered) overlay times, or None if none were given.

    Raises
    ------
    ValueError
        If ``overlay_times`` or ``order`` do not have one entry per epoch.
    """
    n_rows = len(data)

    if overlay_times is not None:
        n_overlay = len(overlay_times)
        if n_overlay != n_rows:
            raise ValueError(
                f"size of overlay_times parameter ({n_overlay}) does "
                f"not match the number of epochs ({n_rows})."
            )
        overlay_times = np.array(overlay_times)
        earliest = np.min(overlay_times)
        latest = np.max(overlay_times)
        # out-of-range values are allowed but worth flagging
        if (times[0] > earliest) or (times[-1] < latest):
            warn(
                "Some values in overlay_times fall outside of the epochs "
                f"time interval (between {times[0]} s and {times[-1]} s)"
            )

    # resolve a callable into concrete indices first
    if callable(order):
        order = order(times, data)

    if order is None:
        return data, overlay_times

    if len(order) != n_rows:
        raise ValueError(
            f"If order is a {type(order).__name__}, its "
            f"length ({len(order)}) must match the length of "
            f"the data ({n_rows})."
        )
    row_index = np.array(order)
    data = data[row_index]
    if overlay_times is not None:
        overlay_times = overlay_times[row_index]

    return data, overlay_times
|
|
|
|
|
|
def _plot_epochs_image(
    image,
    style_axes=True,
    epochs=None,
    picks=None,
    vmin=None,
    vmax=None,
    colorbar=False,
    show=False,
    unit=None,
    cmap=None,
    ax=None,
    overlay_times=None,
    title=None,
    evoked=False,
    ts_args=None,
    combine=None,
    combine_given=False,
    norm=False,
):
    """Draw one epochs image, with optional evoked trace and colorbar.

    Helper for plot_epochs_image. ``ax`` is a dict holding an ``"image"``
    axes and, when used, ``"evoked"`` and ``"colorbar"`` axes. The image is
    drawn with ``imshow`` over the time range of ``epochs``; the figure that
    owns the image axes is returned.
    """
    from matplotlib.ticker import AutoLocator

    # fall back to a sequential map for all-positive data, diverging otherwise
    if cmap is None:
        cmap = "Reds" if norm else "RdBu_r"

    t_first, t_last = epochs.times[0], epochs.times[-1]
    image_ax = ax["image"]

    # draw the image
    cmap = _setup_cmap(cmap, norm=norm)
    n_rows = len(image)
    handle = image_ax.imshow(
        image,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap[0],
        aspect="auto",
        origin="lower",
        interpolation="nearest",
        extent=[t_first, t_last, 0, n_rows],
    )

    # optional axes decoration
    if style_axes:
        image_ax.set_title(title)
        image_ax.set_ylabel("Epochs")
        if not evoked:
            image_ax.set_xlabel("Time (s)")
        image_ax.axis("auto")
        image_ax.axis("tight")
        image_ax.axvline(0, color="k", linewidth=1, linestyle="--")

    # one marker per row, centred vertically within that row
    if overlay_times is not None:
        row_centers = 0.5 + np.arange(n_rows)
        image_ax.plot(overlay_times, row_centers, "k", linewidth=2)
        image_ax.set_xlim(t_first, t_last)

    # draw the evoked
    if evoked:
        from .evoked import plot_compare_evokeds

        # only forward `combine` when the caller chose it explicitly
        pass_combine = combine if combine_given else None
        _picks = [0] if len(picks) == 1 else None  # prevent applying GFP
        evoked_ax = ax["evoked"]
        plot_compare_evokeds(
            {"cond": list(epochs.iter_evoked(copy=False))},
            picks=_picks,
            axes=evoked_ax,
            combine=pass_combine,
            **ts_args,
        )
        evoked_ax.set_xlim(t_first, t_last)
        evoked_ax.lines[0].set_clip_on(True)
        evoked_ax.collections[0].set_clip_on(True)
        evoked_ax.sharex(image_ax)
        # fix the axes for proper updating during interactivity
        evoked_ax.xaxis.set_major_locator(image_ax.xaxis.get_major_locator())
        evoked_ax.yaxis.set_major_locator(AutoLocator())

    fig = image_ax.get_figure()

    # draw the colorbar
    if colorbar:
        if "colorbar" in ax:  # a dedicated colorbar axes was provided
            cbar_ax = ax["colorbar"]
            cbar = cbar_ax.figure.colorbar(handle, cax=cbar_ax)
            cbar.ax.set_ylabel(unit, rotation=270, labelpad=12)
        else:  # let matplotlib place it next to the image axes
            cbar = fig.colorbar(handle, ax=image_ax)
            cbar.ax.set_title(unit)
        if cmap[1]:
            image_ax.CB = DraggableColorbar(
                cbar, handle, kind="epochs_image", ch_type=unit
            )

    # finish
    plt_show(show, fig=fig)
    return fig
|
|
|
|
|
|
def plot_drop_log(
    drop_log,
    threshold=0,
    n_max_plot=20,
    subject=None,
    color="lightgray",
    width=0.8,
    ignore=("IGNORED",),
    show=True,
):
    """Show the channel stats based on a drop_log from Epochs.

    Parameters
    ----------
    drop_log : list of list
        Epoch drop log from Epochs.drop_log.
    threshold : float
        The percentage threshold to use to decide whether or not to
        plot. Default is zero (always plot).
    n_max_plot : int
        Maximum number of channels to show stats for.
    subject : str | None
        The subject name to use in the title of the plot. If ``None``, do not
        display a subject name.

        .. versionchanged:: 0.23
           Added support for ``None``.

        .. versionchanged:: 1.0
           Defaults to ``None``.
    color : tuple | str
        Color to use for the bars.
    width : float
        Width of the bars.
    ignore : list
        The drop reasons to ignore.
    show : bool
        Show figure if True.

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        The figure. Nothing is plotted (and ``None`` is returned) when the
        percentage of dropped epochs is below ``threshold``.
    """
    import matplotlib.pyplot as plt

    from ..epochs import _drop_log_stats

    percent = _drop_log_stats(drop_log, ignore)
    if percent < threshold:
        logger.info(
            "Percent dropped epochs < supplied threshold; not plotting drop log."
        )
        return

    # epochs whose log holds no ignored reason; those with a non-empty log
    # among them are the ones that were actually removed
    considered = [
        entry for entry in drop_log if not any(reason in ignore for reason in entry)
    ]
    n_epochs_before_drop = len(considered)
    absolute = sum(1 for entry in considered if len(entry))

    # tally every non-ignored drop reason
    scores = Counter(
        reason for entry in drop_log for reason in entry if reason not in ignore
    )
    ch_names = np.array(list(scores.keys()))
    counts = np.array(list(scores.values()))

    # init figure, handle easy case (no drops)
    fig, ax = plt.subplots(layout="constrained")
    title = f"{absolute} of {n_epochs_before_drop} epochs removed ({percent:.1f}%)"
    if subject is not None:
        title = f"{subject}: {title}"
    ax.set_title(title)
    if len(ch_names) == 0:
        ax.text(0.5, 0.5, "No drops", ha="center", fontsize=14)
        return fig

    # count epochs that aren't fully caught by `ignore`
    n_used = sum(
        1
        for entry in drop_log
        if len(entry) == 0 or any(reason not in ignore for reason in entry)
    )

    # bar heights, largest first, limited to n_max_plot bars
    n_bars = min(n_max_plot, len(ch_names))
    positions = np.arange(n_bars)
    heights = 100 * counts / n_used
    shown = np.argsort(heights)[::-1][:n_bars]
    ax.bar(positions, heights[shown], color=color, width=width, align="center")
    ax.set_xticks(positions)
    ax.set_xticklabels(
        ch_names[shown], rotation=45, size=10, horizontalalignment="right"
    )
    ax.set_ylabel("% of epochs removed")
    ax.grid(axis="y")
    plt_show(show)
    return fig
|
|
|
|
|
|
@fill_doc
|
|
def plot_epochs(
|
|
epochs,
|
|
picks=None,
|
|
scalings=None,
|
|
n_epochs=20,
|
|
n_channels=20,
|
|
title=None,
|
|
events=False,
|
|
event_color=None,
|
|
order=None,
|
|
show=True,
|
|
block=False,
|
|
decim="auto",
|
|
noise_cov=None,
|
|
butterfly=False,
|
|
show_scrollbars=True,
|
|
show_scalebars=True,
|
|
epoch_colors=None,
|
|
event_id=None,
|
|
group_by="type",
|
|
precompute=None,
|
|
use_opengl=None,
|
|
*,
|
|
theme=None,
|
|
overview_mode=None,
|
|
splash=True,
|
|
):
|
|
"""Visualize epochs.
|
|
|
|
Bad epochs can be marked with a left click on top of the epoch. Bad
|
|
channels can be selected by clicking the channel name on the left side of
|
|
the main axes. Calling this function drops all the selected bad epochs as
|
|
well as bad epochs marked beforehand with rejection parameters.
|
|
|
|
Parameters
|
|
----------
|
|
epochs : instance of Epochs
|
|
The epochs object.
|
|
%(picks_good_data)s
|
|
%(scalings)s
|
|
n_epochs : int
|
|
The number of epochs per view. Defaults to 20.
|
|
n_channels : int
|
|
The number of channels per view. Defaults to 20.
|
|
title : str | None
|
|
The title of the window. If None, the event names (from
|
|
``epochs.event_id``) will be displayed. Defaults to None.
|
|
events : bool | array, shape (n_events, 3)
|
|
Events to show with vertical bars. You can use `~mne.viz.plot_events`
|
|
as a legend for the colors. By default, the coloring scheme is the
|
|
same. ``True`` plots ``epochs.events``. Defaults to ``False`` (do not
|
|
plot events).
|
|
|
|
.. warning:: If the epochs have been resampled, the events no longer
|
|
align with the data.
|
|
|
|
.. versionadded:: 0.14.0
|
|
|
|
.. versionchanged:: 1.6
|
|
Passing ``events=None`` was disallowed.
|
|
The new equivalent is ``events=False``.
|
|
%(event_color)s
|
|
Defaults to ``None``.
|
|
order : array of str | None
|
|
Order in which to plot channel types.
|
|
|
|
.. versionadded:: 0.18.0
|
|
show : bool
|
|
Show figure if True. Defaults to True.
|
|
block : bool
|
|
Whether to halt program execution until the figure is closed.
|
|
Useful for rejecting bad trials on the fly by clicking on an epoch.
|
|
Defaults to False.
|
|
decim : int | 'auto'
|
|
Amount to decimate the data during display for speed purposes.
|
|
You should only decimate if the data are sufficiently low-passed,
|
|
otherwise aliasing can occur. The 'auto' mode (default) uses
|
|
the decimation that results in a sampling rate at least three times
|
|
larger than ``info['lowpass']`` (e.g., a 40 Hz lowpass will result in
|
|
at least a 120 Hz displayed sample rate).
|
|
|
|
.. versionadded:: 0.15.0
|
|
noise_cov : instance of Covariance | str | None
|
|
Noise covariance used to whiten the data while plotting.
|
|
Whitened data channels are scaled by ``scalings['whitened']``,
|
|
and their channel names are shown in italic.
|
|
Can be a string to load a covariance from disk.
|
|
See also :meth:`mne.Evoked.plot_white` for additional inspection
|
|
of noise covariance properties when whitening evoked data.
|
|
For data processed with SSS, the effective dependence between
|
|
magnetometers and gradiometers may introduce differences in scaling,
|
|
consider using :meth:`mne.Evoked.plot_white`.
|
|
|
|
.. versionadded:: 0.16.0
|
|
butterfly : bool
|
|
Whether to directly call the butterfly view.
|
|
|
|
.. versionadded:: 0.18.0
|
|
%(show_scrollbars)s
|
|
%(show_scalebars)s
|
|
|
|
.. versionadded:: 0.24.0
|
|
epoch_colors : list of (n_epochs) list (of n_channels) | None
|
|
Colors to use for individual epochs. If None, use default colors.
|
|
event_id : bool | dict
|
|
Determines to label the event markers on the plot. If ``True``, uses
|
|
``epochs.event_id``. If ``False``, uses integer event codes instead of IDs.
|
|
If a ``dict`` is passed, uses its *keys* as event labels on the plot for
|
|
entries whose *values* are integer codes for events being drawn. Ignored if
|
|
``events=False``.
|
|
|
|
.. versionadded:: 0.20
|
|
%(group_by_browse)s
|
|
%(precompute)s
|
|
%(use_opengl)s
|
|
%(theme_pg)s
|
|
|
|
.. versionadded:: 1.0
|
|
%(overview_mode)s
|
|
|
|
.. versionadded:: 1.1
|
|
%(splash)s
|
|
|
|
.. versionadded:: 1.6
|
|
|
|
Returns
|
|
-------
|
|
%(browser)s
|
|
|
|
Notes
|
|
-----
|
|
The arrow keys (up/down/left/right) can be used to navigate between
|
|
channels and epochs and the scaling can be adjusted with - and + (or =)
|
|
keys, but this depends on the backend matplotlib is configured to use
|
|
(e.g., mpl.use(``TkAgg``) should work). Full screen mode can be toggled
|
|
with f11 key. The amount of epochs and channels per view can be adjusted
|
|
with home/end and page down/page up keys. ``h`` key plots a histogram of
|
|
peak-to-peak values along with the used rejection thresholds. Butterfly
|
|
plot can be toggled with ``b`` key. Left mouse click adds a vertical line
|
|
to the plot. Click 'help' button at bottom left corner of the plotter to
|
|
view all the options.
|
|
|
|
%(notes_2d_backend)s
|
|
|
|
.. versionadded:: 0.10.0
|
|
"""
|
|
from ._figure import _get_browser
|
|
|
|
epochs._handle_empty("raise", "plot")
|
|
epochs.drop_bad()
|
|
info = epochs.info.copy()
|
|
sfreq = info["sfreq"]
|
|
projs = info["projs"]
|
|
projs_on = np.full_like(projs, epochs.proj, dtype=bool)
|
|
if not epochs.proj:
|
|
with info._unlock():
|
|
info["projs"] = list()
|
|
|
|
# handle defaults / check arg validity
|
|
color = _handle_default("color", None)
|
|
scalings = _compute_scalings(scalings, epochs)
|
|
scalings = _handle_default("scalings_plot_raw", scalings)
|
|
if scalings["whitened"] == "auto":
|
|
scalings["whitened"] = 1.0
|
|
units = _handle_default("units", None)
|
|
unit_scalings = _handle_default("scalings", None)
|
|
decim, picks_data = _handle_decim(epochs.info.copy(), decim, None)
|
|
noise_cov = _check_cov(noise_cov, epochs.info)
|
|
_check_option("group_by", group_by, ("selection", "position", "original", "type"))
|
|
# handle event labels
|
|
_validate_type(event_id, (bool, dict, None), "event_id")
|
|
if not event_id: # False or None
|
|
event_id = dict()
|
|
else:
|
|
# make our own copy of the dict
|
|
event_id = dict() if event_id is True else event_id.copy() # to dict
|
|
# TODO: when min py=3.9, change to `epochs.event_id | event_id` (maybe).
|
|
# Passed-in event_id should take precedence, i.e., not replace existing
|
|
# keys *or* repeat existing values. For example, if epochs.event_id has
|
|
# a=1 and passed-in event_id has f=1, the second takes precedence.
|
|
event_values = set(event_id.values())
|
|
event_id.update(
|
|
(k, v)
|
|
for k, v in epochs.event_id.items()
|
|
if k not in event_id and v not in event_values
|
|
)
|
|
event_id_rev = {v: k for k, v in event_id.items()}
|
|
# validate epoch_colors
|
|
_validate_type(epoch_colors, (list, None), "epoch_colors")
|
|
if epoch_colors is not None:
|
|
if len(epoch_colors) != len(epochs.events):
|
|
msg = (
|
|
"epoch_colors must have length equal to the number of "
|
|
f"epochs ({len(epochs)}); got length {len(epoch_colors)}."
|
|
)
|
|
raise ValueError(msg)
|
|
for ix, this_colors in enumerate(epoch_colors):
|
|
_validate_type(this_colors, list, f"epoch_colors[{ix}]")
|
|
if len(this_colors) != len(epochs.ch_names):
|
|
msg = (
|
|
f"epoch colors for epoch {ix} has length "
|
|
f"{len(this_colors)}, expected {len(epochs.ch_names)}."
|
|
)
|
|
raise ValueError(msg)
|
|
|
|
# handle time dimension
|
|
n_epochs = min(n_epochs, len(epochs))
|
|
n_times = len(epochs) * len(epochs.times)
|
|
duration = n_epochs * len(epochs.times) / sfreq
|
|
# NB: this includes start and end of data:
|
|
boundary_times = np.arange(len(epochs) + 1) * len(epochs.times) / sfreq
|
|
|
|
# events
|
|
_validate_type(events, (bool, np.ndarray), "events")
|
|
if events is False:
|
|
event_nums = None
|
|
event_times = None
|
|
else: # True or ndarray
|
|
if events is True: # use epochs.events
|
|
events = epochs.events
|
|
event_nums = events[:, 2]
|
|
event_samps = events[:, 0]
|
|
epoch_n_samps = len(epochs.times)
|
|
# handle overlapping epochs (each event may show up in multiple places)
|
|
boundaries = epochs.events[:, [0]] + np.array([-1, 1]) * epochs.time_as_index(
|
|
[0, epochs.tmax]
|
|
)
|
|
in_bounds = np.logical_and(
|
|
boundaries[:, [0]] <= event_samps, event_samps < boundaries[:, [1]]
|
|
)
|
|
event_ixs = [np.nonzero(a)[0] for a in in_bounds.T]
|
|
warned = False
|
|
event_times = list()
|
|
event_numbers = list()
|
|
for samp, num, _ixs in zip(event_samps, event_nums, event_ixs):
|
|
relevant_epoch_events = epochs.events[:, 0][_ixs]
|
|
if len(relevant_epoch_events) > 1 and not warned:
|
|
logger.info(
|
|
"You seem to have overlapping epochs. Some event "
|
|
"lines may be duplicated in the plot."
|
|
)
|
|
warned = True
|
|
offsets = samp - relevant_epoch_events + epochs.time_as_index(0)
|
|
this_event_times = (_ixs * epoch_n_samps + offsets) / sfreq
|
|
event_times.extend(this_event_times)
|
|
event_numbers.extend([num] * len(_ixs))
|
|
event_nums = np.array(event_numbers)
|
|
event_times = np.array(event_times)
|
|
|
|
event_color_dict = _make_event_color_dict(event_color, events, event_id)
|
|
|
|
# determine trace order
|
|
picks = _picks_to_idx(info, picks)
|
|
n_channels = min(n_channels, len(picks))
|
|
ch_names = np.array(epochs.ch_names)
|
|
ch_types = np.array(epochs.get_channel_types())
|
|
order = _get_channel_plotting_order(order, ch_types, picks)
|
|
selections = None
|
|
if group_by in ("selection", "position"):
|
|
selections = _setup_channel_selections(epochs, group_by, order)
|
|
order = np.concatenate(list(selections.values()))
|
|
default_selection = list(selections)[0]
|
|
n_channels = len(selections[default_selection])
|
|
|
|
# generate window title
|
|
if title is None:
|
|
title = epochs._get_name(count="total", sep="•", ms=None)
|
|
if title is None or len(title) == 0:
|
|
title = "Epochs"
|
|
elif not isinstance(title, str):
|
|
raise TypeError(f"title must be None or a string, got a {type(title)}")
|
|
|
|
precompute = _handle_precompute(precompute)
|
|
params = dict(
|
|
inst=epochs,
|
|
info=info,
|
|
n_epochs=n_epochs,
|
|
# channels and channel order
|
|
ch_names=ch_names,
|
|
ch_types=ch_types,
|
|
ch_order=order,
|
|
picks=order[:n_channels],
|
|
n_channels=n_channels,
|
|
picks_data=picks_data,
|
|
group_by=group_by,
|
|
ch_selections=selections,
|
|
# time
|
|
t_start=0,
|
|
duration=duration,
|
|
n_times=n_times,
|
|
first_time=0,
|
|
time_format="float",
|
|
decim=decim,
|
|
boundary_times=boundary_times,
|
|
# events
|
|
event_id_rev=event_id_rev,
|
|
event_color_dict=event_color_dict,
|
|
event_nums=event_nums,
|
|
event_times=event_times,
|
|
# preprocessing
|
|
projs=projs,
|
|
projs_on=projs_on,
|
|
apply_proj=epochs.proj,
|
|
remove_dc=True,
|
|
filter_coefs=None,
|
|
filter_bounds=None,
|
|
noise_cov=noise_cov,
|
|
use_noise_cov=noise_cov is not None,
|
|
# scalings
|
|
scalings=scalings,
|
|
units=units,
|
|
unit_scalings=unit_scalings,
|
|
# colors
|
|
ch_color_bad="lightgray",
|
|
ch_color_dict=color,
|
|
epoch_color_bad=(1, 0, 0),
|
|
epoch_colors=epoch_colors,
|
|
# display
|
|
butterfly=butterfly,
|
|
clipping=None,
|
|
scrollbars_visible=show_scrollbars,
|
|
scalebars_visible=show_scalebars,
|
|
window_title=title,
|
|
xlabel="Epoch number",
|
|
# pyqtgraph-specific
|
|
precompute=precompute,
|
|
use_opengl=use_opengl,
|
|
theme=theme,
|
|
overview_mode=overview_mode,
|
|
splash=splash,
|
|
)
|
|
|
|
fig = _get_browser(show=show, block=block, **params)
|
|
|
|
return fig
|
|
|
|
|
|
@legacy(alt="Epochs.compute_psd().plot()")
@verbose
def plot_epochs_psd(
    epochs,
    fmin=0,
    fmax=np.inf,
    tmin=None,
    tmax=None,
    proj=False,
    bandwidth=None,
    adaptive=False,
    low_bias=True,
    normalization="length",
    picks=None,
    ax=None,
    color="black",
    xscale="linear",
    area_mode="std",
    area_alpha=0.33,
    dB=True,
    estimate="power",
    show=True,
    n_jobs=None,
    average=False,
    line_alpha=None,
    spatial_colors=True,
    sphere=None,
    exclude="bads",
    verbose=None,
):
    """%(plot_psd_doc)s.

    Parameters
    ----------
    epochs : instance of Epochs
        The epochs whose power spectral density is computed and plotted.
    %(fmin_fmax_psd)s
    %(tmin_tmax_psd)s
    %(proj_psd)s
    bandwidth : float
        Bandwidth (in Hz) of the multi taper windowing function. By default,
        a window half-bandwidth of 4 is used.
    adaptive : bool
        Whether to combine the tapered spectra into the PSD using adaptive
        weights (slow, use n_jobs >> 1 to speed up computation).
    low_bias : bool
        If True, only tapers with more than 90%% spectral concentration
        within the bandwidth are used.
    %(normalization)s
    %(picks_good_data_noref)s
    %(ax_plot_psd)s
    %(color_plot_psd)s
    %(xscale_plot_psd)s
    %(area_mode_plot_psd)s
    %(area_alpha_plot_psd)s
    %(dB_plot_psd)s
    %(estimate_plot_psd)s
    %(show)s
    %(n_jobs)s
    %(average_plot_psd)s
    %(line_alpha_plot_psd)s
    %(spatial_colors_psd)s
    %(sphere_topomap_auto)s
    exclude : list of str | 'bads'
        Names of the channels to leave out of the plot. If 'bads', the bad
        channels are excluded. An empty list plots all channels (including
        channels marked "bad", if any).

        .. versionadded:: 0.24.0
    %(verbose)s

    Returns
    -------
    fig : instance of Figure
        Figure with frequency spectra of the data channels.

    Notes
    -----
    %(notes_plot_*_psd_func)s
    """
    from ..time_frequency import Spectrum

    # NOTE(review): ``_split_psd_kwargs`` is given no explicit ``kwargs`` here;
    # presumably it gathers this function's arguments from the calling frame,
    # so this call must stay directly in this function body -- confirm against
    # the helper before refactoring.
    init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot)
    # First compute the spectrum, then hand the plotting-only options to it.
    spectrum = epochs.compute_psd(**init_kw)
    return spectrum.plot(**plot_kw)
|