2531 lines
103 KiB
Python
2531 lines
103 KiB
Python
"""Figure classes for MNE-Python's 2D plots.

Class Hierarchy
---------------
MNEFigParams    Container object, attached to MNEFigure by default. Sets
                close_key='escape' plus whatever other key-value pairs are
                passed to its constructor.

matplotlib.figure.Figure
 └ MNEFigure
    ├ MNEBrowseFigure       Interactive figure for scrollable data.
    │                       Generated by:
    │                       - raw.plot()
    │                       - epochs.plot()
    │                       - ica.plot_sources(raw)
    │                       - ica.plot_sources(epochs)
    │
    ├ MNEAnnotationFigure   GUI for adding annotations to Raw
    │
    ├ MNESelectionFigure    GUI for spatial channel selection. raw.plot()
    │                       and epochs.plot() will generate one of these
    │                       alongside an MNEBrowseFigure when
    │                       group_by == 'selection' or 'position'
    │
    └ MNELineFigure         Interactive figure for non-scrollable data.
                            Generated by:
                            - spectrum.plot()
                            - evoked.plot()         TODO Not yet implemented
                            - evoked.plot_white()   TODO Not yet implemented
                            - evoked.plot_joint()   TODO Not yet implemented
"""
|
||
|
||
# Authors: The MNE-Python contributors.
|
||
# License: BSD-3-Clause
|
||
# Copyright the MNE-Python contributors.
|
||
|
||
import datetime
|
||
import platform
|
||
from collections import OrderedDict
|
||
from contextlib import contextmanager
|
||
from functools import partial
|
||
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
from matplotlib import get_backend
|
||
from matplotlib.figure import Figure
|
||
|
||
from .._fiff.pick import (
|
||
_DATA_CH_TYPES_ORDER_DEFAULT,
|
||
_DATA_CH_TYPES_SPLIT,
|
||
_EYETRACK_CH_TYPES_SPLIT,
|
||
_FNIRS_CH_TYPES_SPLIT,
|
||
_VALID_CHANNEL_TYPES,
|
||
channel_indices_by_type,
|
||
pick_types,
|
||
)
|
||
from ..fixes import _close_event
|
||
from ..utils import Bunch, _click_ch_name, check_version, logger
|
||
from ._figure import BrowserBase
|
||
from .utils import (
|
||
DraggableLine,
|
||
_events_off,
|
||
_fake_click,
|
||
_fake_keypress,
|
||
_fake_scroll,
|
||
_merge_annotations,
|
||
_set_window_title,
|
||
_validate_if_list_of_axes,
|
||
plot_sensors,
|
||
plt_show,
|
||
)
|
||
|
||
# Identifier of this plotting backend module.
# NOTE(review): presumably read by backend-selection code elsewhere -- confirm.
name = "matplotlib"
# Result of matplotlib.get_backend() at the time this module is imported.
BACKEND = get_backend()

# CONSTANTS (inches)
ANNOTATION_FIG_PAD = 0.1
ANNOTATION_FIG_MIN_H = 2.9  # fixed part, not including radio buttons/labels
ANNOTATION_FIG_W = 5.0
ANNOTATION_FIG_CHECKBOX_COLUMN_W = 0.5
# Gates the code paths below that touch ``buttons.circles`` on radio buttons.
# NOTE(review): presumably True when the installed matplotlib is older than
# 3.7 -- confirm against check_version's semantics.
_OLD_BUTTONS = not check_version("matplotlib", "3.7")
|
||
|
||
|
||
class MNEFigure(Figure):
    """Common base for MNE 2D figures and dialogs (a matplotlib ``Figure``)."""

    def __init__(self, **kwargs):
        """Create the figure and attach (or extend) the ``self.mne`` params.

        Parameters
        ----------
        **kwargs
            ``figsize`` is forwarded to ``matplotlib.figure.Figure``; every
            other entry ends up as an attribute of ``self.mne``.
        """
        from matplotlib import rcParams

        # figsize is the only kwarg handed to matplotlib's Figure()
        super().__init__(figsize=kwargs.pop("figsize", None))
        # colors we almost always need; caller-supplied values win
        kwargs.setdefault("fgcolor", rcParams["axes.edgecolor"])
        kwargs.setdefault("bgcolor", rcParams["axes.facecolor"])

        if hasattr(self, "mne"):
            # params already exist (e.g. created by BrowserBase): only add the
            # attributes that are still missing, never overwrite existing ones
            missing = {
                key: value
                for key, value in kwargs.items()
                if not hasattr(self.mne, key)
            }
            for key, value in missing.items():
                setattr(self.mne, key, value)
        else:
            from mne.viz._figure import BrowserParams

            self.mne = BrowserParams(**kwargs)

    def _close(self, event=None):
        """Drop the parent figure's references to this figure on close."""
        logger.debug(f"Closing {self!r}")
        parent = getattr(self.mne, "parent_fig", None)
        fig_name = getattr(self.mne, "fig_name", None)
        if parent is None:
            return  # not a child figure: nothing to unlink
        try:
            parent.mne.child_figs.remove(self)
        except ValueError:
            pass  # already removed (on its own, probably?)
        if fig_name is not None:
            # clear the named handle (e.g. ``parent.mne.fig_help``)
            setattr(parent.mne, fig_name, None)

    def _keypress(self, event):
        """React to the close key and to the F11 full-screen toggle."""
        pressed = event.key
        if pressed == self.mne.close_key:
            plt.close(self)
        elif pressed == "f11":  # full screen
            self.canvas.manager.full_screen_toggle()

    def _buttonpress(self, event):
        """Handle buttonpress events (no-op in the base class)."""

    def _pick(self, event):
        """Handle matplotlib pick events (no-op in the base class)."""

    def _resize(self, event):
        """Handle window resize events (no-op in the base class)."""

    def _add_default_callbacks(self, **kwargs):
        """Swap matplotlib's default key handlers for the MNE callbacks.

        Parameters
        ----------
        **kwargs
            Extra ``event_name=callback`` pairs; they are merged into (and
            may replace) the default set before connecting.
        """
        # disconnect whatever is currently registered for key presses
        registry = self.canvas.callbacks
        for cid in list(registry.callbacks.get("key_press_event", {})):
            registry.disconnect(cid)
        # our own handlers, overridable via kwargs
        handlers = {
            "resize_event": self._resize,
            "key_press_event": self._keypress,
            "button_press_event": self._buttonpress,
            "close_event": self._close,
            "pick_event": self._pick,
        }
        handlers.update(kwargs)
        # remember the connection ids, keyed by event name
        self.mne._callback_ids = {
            event_name: self.canvas.mpl_connect(event_name, handler)
            for event_name, handler in handlers.items()
        }

    def _get_dpi_ratio(self):
        """Return the canvas DPI ratio (1.0 if the canvas defines none)."""
        canvas = self.canvas
        # ``_device_scale`` takes precedence over ``_dpi_ratio`` when both exist
        ratio = getattr(canvas, "_dpi_ratio", 1.0)
        return getattr(canvas, "_device_scale", ratio)

    def _get_size_px(self):
        """Return the figure size in pixels, corrected for the DPI ratio."""
        ratio = self._get_dpi_ratio()
        return self.get_size_inches() * self.dpi / ratio

    def _inch_to_rel(self, dim_inches, horiz=True):
        """Express a length in inches as a fraction of figure width/height."""
        fig_w, fig_h = self.get_size_inches()
        return dim_inches / (fig_w if horiz else fig_h)
|
||
|
||
|
||
class MNEAnnotationFigure(MNEFigure):
    """Interactive dialog figure for annotations."""

    def _close(self, event):
        """Handle close events (via keypress or window [x]).

        Deactivates the parent's span selector, removes its hover line and
        disconnects its ``motion_notify_event`` callback, then runs the
        base-class cleanup.
        """
        parent = self.mne.parent_fig
        # disable span selector
        parent.mne.ax_main.selector.active = False
        # clear hover line
        parent._remove_annotation_hover_line()
        # disconnect hover callback
        callback_id = parent.mne._callback_ids["motion_notify_event"]
        parent.canvas.callbacks.disconnect(callback_id)
        # do all the other cleanup activities
        super()._close(event)

    def _keypress(self, event):
        """Handle keypress events: edit the label text, add a label, or close.

        Printable single characters are appended to ``self.label``;
        ``backspace`` removes the last character; ``enter`` asks the parent
        figure to add the label; multi-character keys and ``;`` are ignored.
        """
        text = self.label.get_text()
        key = event.key
        if key is None:
            # matplotlib documents ``KeyEvent.key`` as "None or str";
            # ``len(None)`` below would raise TypeError, so ignore such events
            return
        if key == self.mne.close_key:
            plt.close(self)
        elif key == "backspace":
            text = text[:-1]
        elif key == "enter":
            self.mne.parent_fig._add_annotation_label(event)
            return
        elif len(key) > 1 or key == ";":  # ignore modifier keys
            return
        else:
            text = text + key
        self.label.set_text(text)
        self.canvas.draw()

    def _radiopress(self, event, *, draw=True):
        """Handle Radiobutton clicks for Annotation label selection.

        Parameters
        ----------
        event
            The widget event; not read by this method.
        draw : bool
            Whether to redraw this figure's canvas afterwards.
        """
        # update which button looks active
        buttons = self.mne.radio_ax.buttons
        labels = [label.get_text() for label in buttons.labels]
        idx = labels.index(buttons.value_selected)
        self._set_active_button(idx, draw=False)
        # update click-drag rectangle color
        color = self.mne.parent_fig.mne.annotation_segment_colors[labels[idx]]
        selector = self.mne.parent_fig.mne.ax_main.selector
        # https://github.com/matplotlib/matplotlib/issues/20618
        # https://github.com/matplotlib/matplotlib/pull/20693
        selector.set_props(color=color, facecolor=color)
        if draw:
            self.canvas.draw()

    def _click_override(self, event):
        """Override MPL radiobutton click detector to use transData.

        Only valid on the ``_OLD_BUTTONS`` code path (asserted). Among the
        buttons whose label or circle was hit, the one whose circle center is
        closest to the click becomes active.
        """
        assert _OLD_BUTTONS
        ax = self.mne.radio_ax
        buttons = ax.buttons
        if buttons.ignore(event) or event.button != 1 or event.inaxes != ax:
            return
        # click position in data coordinates of the radio axes
        pclicked = ax.transData.inverted().transform((event.x, event.y))
        distances = {}
        for i, (p, t) in enumerate(zip(buttons.circles, buttons.labels)):
            if (
                t.get_window_extent().contains(event.x, event.y)
                or np.linalg.norm(pclicked - p.center) < p.radius
            ):
                distances[i] = np.linalg.norm(pclicked - p.center)
        if len(distances) > 0:
            closest = min(distances, key=distances.get)
            buttons.set_active(closest)

    def _set_active_button(self, idx, *, draw=True):
        """Set active button in annotation dialog figure.

        Parameters
        ----------
        idx : int
            Index of the radio button to activate.
        draw : bool
            Whether to redraw this figure's canvas afterwards.
        """
        buttons = self.mne.radio_ax.buttons
        logger.debug(f"buttons: {buttons}")
        logger.debug(f"active idx: {idx}")
        # activate without triggering the buttons' own callbacks
        with _events_off(buttons):
            buttons.set_active(idx)
        if _OLD_BUTTONS:
            logger.debug(f"circles: {buttons.circles}")
            for circle in buttons.circles:
                circle.set_facecolor(self.mne.parent_fig.mne.bgcolor)
            # active circle gets filled in, partially transparent
            color = list(buttons.circles[idx].get_edgecolor())
            logger.debug(f"color: {color}")
            color[-1] = 0.5
            buttons.circles[idx].set_facecolor(color)
        if draw:
            self.canvas.draw()
|
||
|
||
|
||
class MNESelectionFigure(MNEFigure):
    """Dialog figure through which channel selections are chosen."""

    def _close(self, event):
        """Unlink from the parent figure and close the parent as well."""
        parent = self.mne.parent_fig
        parent.mne.child_figs.remove(self)
        self.mne.fig_selection = None
        # selection fig & main fig tightly integrated; closing one closes both
        plt.close(parent)

    def _keypress(self, event):
        """Forward up/down/b to the parent; handle other keys in the base."""
        if event.key not in ("up", "down", "b"):
            # check for close key
            super()._keypress(event)
            return
        self.mne.parent_fig._keypress(event)

    def _radiopress(self, event):
        """React to a RadioButton click that picks a channel-selection group."""
        logger.debug(f"Got radio press: {repr(event)}")
        parent = self.mne.parent_fig
        buttons = self.mne.radio_ax.buttons
        selections = parent.mne.ch_selections
        label_texts = [lab.get_text() for lab in buttons.labels]
        chosen = buttons.value_selected
        if chosen == "Custom" and not len(selections["Custom"]):
            # an empty custom selection cannot be shown: revert silently
            with _events_off(buttons):
                buttons.set_active(self.mne.old_selection)
            return
        # clicking a selection cancels butterfly mode
        if parent.mne.butterfly:
            logger.debug("Disabling butterfly mode")
            parent._toggle_butterfly()
        with _events_off(buttons):
            buttons.set_active(label_texts.index(chosen))
        parent._update_selection()

    def _set_custom_selection(self):
        """Store the lasso-selected channels as the "Custom" selection."""
        lassoed = self.lasso.selection
        parent = self.mne.parent_fig
        buttons = self.mne.radio_ax.buttons
        if not len(lassoed):
            return
        label_texts = [lab.get_text() for lab in buttons.labels]
        # indices of the parent's channel names that were lassoed
        in_lasso = np.isin(parent.mne.ch_names, lassoed)
        parent.mne.ch_selections["Custom"] = in_lasso.nonzero()[0]
        buttons.set_active(label_texts.index("Custom"))

    def _style_radio_buttons_butterfly(self):
        """Restyle the radio buttons to match the parent's butterfly state."""
        parent = self.mne.parent_fig
        buttons = self.mne.radio_ax.buttons
        in_butterfly = parent.mne.butterfly
        # Show all radio buttons as selected when in butterfly mode
        if in_butterfly:
            fill = buttons.activecolor
        else:
            fill = parent.mne.bgcolor
        if _OLD_BUTTONS:
            for circle in buttons.circles:
                circle.set_facecolor(fill)
        # when leaving butterfly mode, make most-recently-used selection active
        if not in_butterfly:
            with _events_off(buttons):
                buttons.set_active(self.mne.old_selection)
        # update the sensors too
        parent._update_highlighted_sensors()
|
||
|
||
|
||
class MNEBrowseFigure(BrowserBase, MNEFigure):
|
||
"""Interactive figure with scrollbars, for data browsing."""
|
||
|
||
def __init__(self, inst, figsize, ica=None, xlabel="Time (s)", **kwargs):
    """Build the browser figure and all of its fixed UI elements.

    Creates the main axes, the horizontal and vertical scrollbar axes, the
    scrollbar selection patches, the vertical marker line(s), the help
    button, an optional projector button, and placeholder traces, and stores
    the handles on ``self.mne``.

    Parameters
    ----------
    inst, figsize, ica
        Merged into ``kwargs`` and passed to both ``BrowserBase.__init__``
        and ``MNEFigure.__init__``.
    xlabel : str
        Label of the horizontal scrollbar axis and of the (zen-mode) main
        axes x-label.
    **kwargs
        Further parameters forwarded to both base-class initializers.
    """
    from matplotlib.colors import to_rgba_array
    from matplotlib.patches import Rectangle
    from matplotlib.ticker import (
        FixedFormatter,
        FixedLocator,
        FuncFormatter,
        NullFormatter,
    )
    from matplotlib.transforms import blended_transform_factory
    from matplotlib.widgets import Button
    from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
    from mpl_toolkits.axes_grid1.axes_size import Fixed

    self.backend_name = "matplotlib"

    # fold the named arguments back into kwargs so both bases receive them
    kwargs.update({"inst": inst, "figsize": figsize, "ica": ica, "xlabel": xlabel})

    BrowserBase.__init__(self, **kwargs)
    MNEFigure.__init__(self, **kwargs)

    # MAIN AXES: default sizes (inches)
    # XXX simpler with constrained_layout? (when it's no longer "beta")
    l_margin = 1.0
    r_margin = 0.1
    b_margin = 0.45
    t_margin = 0.25
    scroll_width = 0.25
    hscroll_dist = 0.25
    vscroll_dist = 0.1
    help_width = scroll_width * 2
    # MAIN AXES: default margins (figure-relative coordinates)
    left = self._inch_to_rel(l_margin - vscroll_dist - help_width)
    right = 1 - self._inch_to_rel(r_margin)
    bottom = self._inch_to_rel(b_margin, horiz=False)
    top = 1 - self._inch_to_rel(t_margin, horiz=False)
    width = right - left
    height = top - bottom
    position = [left, bottom, width, height]
    # Main axes must be a subplot for subplots_adjust to work (so user can
    # adjust margins). That's why we don't use the Divider class directly.
    ax_main = self.add_subplot(1, 1, 1, position=position)
    self.subplotpars.update(left=left, bottom=bottom, top=top, right=right)
    div = make_axes_locatable(ax_main)
    # this only gets shown in zen mode
    self.mne.zen_xlabel = ax_main.set_xlabel(xlabel)
    self.mne.zen_xlabel.set_visible(not self.mne.scrollbars_visible)
    # make sure background color of the axis is set
    if "bgcolor" in kwargs:
        ax_main.set_facecolor(kwargs["bgcolor"])

    # SCROLLBARS
    ax_hscroll = div.append_axes(
        position="bottom", size=Fixed(scroll_width), pad=Fixed(hscroll_dist)
    )
    ax_vscroll = div.append_axes(
        position="right", size=Fixed(scroll_width), pad=Fixed(vscroll_dist)
    )
    ax_hscroll.get_yaxis().set_visible(False)
    ax_hscroll.set_xlabel(xlabel)
    ax_vscroll.set_axis_off()
    # HORIZONTAL SCROLLBAR PATCHES (FOR MARKING BAD EPOCHS)
    if self.mne.is_epochs:
        epoch_nums = self.mne.inst.selection
        for ix, _ in enumerate(epoch_nums):
            start = self.mne.boundary_times[ix]
            # every patch gets the spacing of the first two boundary times
            # (note: this rebinds the ``width`` computed for ``position``)
            width = np.diff(self.mne.boundary_times[:2])[0]
            ax_hscroll.add_patch(
                Rectangle(
                    (start, 0),
                    width,
                    1,
                    color="none",
                    zorder=self.mne.zorder["patch"],
                )
            )
        # both axes, major ticks: gridlines
        for _ax in (ax_main, ax_hscroll):
            _ax.xaxis.set_major_locator(FixedLocator(self.mne.boundary_times[1:-1]))
            _ax.xaxis.set_major_formatter(NullFormatter())
        grid_kwargs = dict(
            color=self.mne.fgcolor, axis="x", zorder=self.mne.zorder["grid"]
        )
        ax_main.grid(linewidth=2, linestyle="dashed", **grid_kwargs)
        ax_hscroll.grid(alpha=0.5, linewidth=0.5, linestyle="solid", **grid_kwargs)
        # main axes, minor ticks: ticklabel (epoch number) for every epoch
        ax_main.xaxis.set_minor_locator(FixedLocator(self.mne.midpoints))
        ax_main.xaxis.set_minor_formatter(FixedFormatter(epoch_nums))
        # hscroll axes, minor ticks: up to 20 ticklabels (epoch numbers)
        ax_hscroll.xaxis.set_minor_locator(
            FixedLocator(self.mne.midpoints, nbins=20)
        )
        ax_hscroll.xaxis.set_minor_formatter(
            FuncFormatter(lambda x, pos: self._get_epoch_num_from_time(x))
        )
        # hide some ticks
        ax_main.tick_params(axis="x", which="major", bottom=False)
        ax_hscroll.tick_params(axis="x", which="both", bottom=False)
    else:
        # RAW / ICA X-AXIS TICK & LABEL FORMATTING
        ax_main.xaxis.set_major_formatter(
            FuncFormatter(partial(self._xtick_formatter, ax_type="main"))
        )
        ax_hscroll.xaxis.set_major_formatter(
            FuncFormatter(partial(self._xtick_formatter, ax_type="hscroll"))
        )
        if self.mne.time_format != "float":
            for _ax in (ax_main, ax_hscroll):
                _ax.set_xlabel("Time (HH:MM:SS)")

    # VERTICAL SCROLLBAR PATCHES (COLORED BY CHANNEL TYPE)
    ch_order = self.mne.ch_order
    for ix, pick in enumerate(ch_order):
        # bad channels use ch_color_bad; others look up their type's color
        this_color = (
            self.mne.ch_color_bad
            if self.mne.ch_names[pick] in self.mne.info["bads"]
            else self.mne.ch_color_dict
        )
        if isinstance(this_color, dict):
            this_color = this_color[self.mne.ch_types[pick]]
        ax_vscroll.add_patch(
            Rectangle(
                (0, ix), 1, 1, color=this_color, zorder=self.mne.zorder["patch"]
            )
        )
    # inverted y-limits: index 0 is drawn at the top
    ax_vscroll.set_ylim(len(ch_order), 0)
    ax_vscroll.set_visible(not self.mne.butterfly)
    # SCROLLBAR VISIBLE SELECTION PATCHES
    sel_kwargs = dict(
        alpha=0.3, linewidth=4, clip_on=False, edgecolor=self.mne.fgcolor
    )
    vsel_patch = Rectangle(
        (0, 0), 1, self.mne.n_channels, facecolor=self.mne.bgcolor, **sel_kwargs
    )
    ax_vscroll.add_patch(vsel_patch)
    hsel_facecolor = np.average(
        np.vstack(
            (to_rgba_array(self.mne.fgcolor), to_rgba_array(self.mne.bgcolor))
        ),
        axis=0,
        weights=(3, 1),
    )  # 75% foreground, 25% background
    hsel_patch = Rectangle(
        (self.mne.t_start, 0),
        self.mne.duration,
        1,
        facecolor=hsel_facecolor,
        **sel_kwargs,
    )
    ax_hscroll.add_patch(hsel_patch)
    ax_hscroll.set_xlim(
        self.mne.first_time,
        self.mne.first_time + self.mne.n_times / self.mne.info["sfreq"],
    )
    # VLINE
    vline_color = (0.0, 0.75, 0.0)
    vline_kwargs = dict(visible=False, zorder=self.mne.zorder["vline"])
    if self.mne.is_epochs:
        # one line per visible epoch; no marker in the horizontal scrollbar
        x = np.arange(self.mne.n_epochs)
        vline = ax_main.vlines(x, 0, 1, colors=vline_color, **vline_kwargs)
        vline.set_transform(
            blended_transform_factory(ax_main.transData, ax_main.transAxes)
        )
        vline_hscroll = None
    else:
        vline = ax_main.axvline(0, color=vline_color, **vline_kwargs)
        vline_hscroll = ax_hscroll.axvline(0, color=vline_color, **vline_kwargs)
    vline_text = ax_main.annotate(
        "",
        xy=(0, 0),
        xycoords="axes fraction",
        xytext=(-2, 0),
        textcoords="offset points",
        fontsize=10,
        ha="right",
        va="center",
        color=vline_color,
        **vline_kwargs,
    )

    # HELP BUTTON: initialize in the wrong spot...
    ax_help = div.append_axes(
        position="left", size=Fixed(help_width), pad=Fixed(vscroll_dist)
    )
    # HELP BUTTON: ...move it down by changing its locator
    loc = div.new_locator(nx=0, ny=0)
    ax_help.set_axes_locator(loc)
    # HELP BUTTON: make it a proper button
    with _patched_canvas(ax_help.figure):
        self.mne.button_help = Button(ax_help, "Help")
    # PROJ BUTTON
    # only created when there are projectors and ``inst.proj`` is falsy
    ax_proj = None
    if len(self.mne.projs) and not self.mne.inst.proj:
        proj_button_pos = [
            1 - self._inch_to_rel(r_margin + scroll_width),  # left
            self._inch_to_rel(b_margin, horiz=False),  # bottom
            self._inch_to_rel(scroll_width),  # width
            self._inch_to_rel(scroll_width, horiz=False),  # height
        ]
        loc = div.new_locator(nx=4, ny=0)
        ax_proj = self.add_axes(proj_button_pos)
        ax_proj.set_axes_locator(loc)
        with _patched_canvas(ax_help.figure):
            self.mne.button_proj = Button(ax_proj, "Prj")

    # INIT TRACES
    # one all-NaN line per visible channel; real data is filled in later
    self.mne.trace_kwargs = dict(antialiased=True, linewidth=0.5)
    self.mne.traces = ax_main.plot(
        np.full((1, self.mne.n_channels), np.nan), **self.mne.trace_kwargs
    )

    # SAVE UI ELEMENT HANDLES
    vars(self.mne).update(
        ax_main=ax_main,
        ax_help=ax_help,
        ax_proj=ax_proj,
        ax_hscroll=ax_hscroll,
        ax_vscroll=ax_vscroll,
        vsel_patch=vsel_patch,
        hsel_patch=hsel_patch,
        vline=vline,
        vline_hscroll=vline_hscroll,
        vline_text=vline_text,
    )
|
||
|
||
def _get_size(self):
    """Return the figure size as reported by ``get_size_inches``."""
    size_inches = self.get_size_inches()
    return size_inches
|
||
|
||
def _resize(self, event):
    """Recompute margins after a resize of an mne_browse-style figure.

    Applies the new subplot margins, rescales the zen-mode width/height
    bookkeeping by the old/new size ratio, stores the new pixel size and
    schedules a redraw.
    """
    prev_w, prev_h = self.mne.fig_size_px
    cur_w, cur_h = self._get_size_px()
    margins = _calc_new_margins(self, prev_w, prev_h, cur_w, cur_h)
    self.subplots_adjust(**margins)
    # zen mode bookkeeping
    self.mne.zen_w *= prev_w / cur_w
    self.mne.zen_h *= prev_h / cur_h
    self.mne.fig_size_px = (cur_w, cur_h)
    self.canvas.draw_idle()
|
||
|
||
def _hover(self, event):
    """Handle motion event when annotating.

    When the cursor is over an annotation span in the main axes (no mouse
    button pressed, draggable annotations enabled), show or move a
    ``DraggableLine`` on the nearer edge of that span; otherwise remove the
    hover line.
    """
    # ignore drags, events without data coordinates, and other axes
    if (
        event.button is not None
        or event.xdata is None
        or event.inaxes != self.mne.ax_main
    ):
        return
    if not self.mne.draggable_annotations:
        self._remove_annotation_hover_line()
        return
    from matplotlib.patheffects import Normal, Stroke

    for coll in self.mne.annotations:
        if coll.contains(event)[0]:
            path = coll.get_paths()
            assert len(path) == 1
            path = path[0]
            color = coll.get_edgecolors()[0]
            ylim = self.mne.ax_main.get_ylim()
            # are we on the left or right edge?
            _l = path.vertices[:, 0].min()
            _r = path.vertices[:, 0].max()
            x = _l if abs(event.xdata - _l) < abs(event.xdata - _r) else _r
            # vertices that sit on the chosen edge
            mask = path.vertices[:, 0] == x

            def drag_callback(x0):
                # move the chosen edge of the span to the dragged position
                path.vertices[mask, 0] = x0

            # create or update the DraggableLine
            hover_line = self.mne.annotation_hover_line
            if hover_line is None:
                line = self.mne.ax_main.plot(
                    [x, x], ylim, color=color, linewidth=2, pickradius=5.0
                )[0]
                hover_line = DraggableLine(
                    line, self._modify_annotation, drag_callback
                )
            else:
                hover_line.set_x(x)
                hover_line.drag_callback = drag_callback
            # style the line
            line = hover_line.line
            patheff = [Stroke(linewidth=4, foreground=color, alpha=0.5), Normal()]
            # add the stroke effect only while the cursor is on the line
            line.set_path_effects(
                patheff if line.contains(event)[0] else patheff[1:]
            )
            self.mne.ax_main.selector.active = False
            self.mne.annotation_hover_line = hover_line
            self.canvas.draw_idle()
            # only the first span containing the cursor is handled
            return
    # cursor is not over any annotation span
    self._remove_annotation_hover_line()
|
||
|
||
def _keypress(self, event):
    """Handle keypress events.

    Dispatches on ``event.key``: arrow keys scroll channels/time, ``=``/``+``
    and ``-`` rescale traces, pageup/pagedown change the number of visible
    channels, home/end change the visible duration, and single letters
    toggle the various modes and dialogs. Unhandled keys go to the base
    class (close key / full-screen toggle).
    """
    key = event.key
    n_channels = self.mne.n_channels
    if self.mne.is_epochs:
        last_time = self.mne.n_times / self.mne.info["sfreq"]
    else:
        last_time = self.mne.inst.times[-1]
    # scroll up/down
    if key in ("down", "up", "shift+down", "shift+up"):
        # shift is ignored for vertical scrolling
        key = key.split("+")[-1]
        direction = -1 if key == "up" else 1
        # butterfly case
        if self.mne.butterfly:
            return
        # group_by case
        elif self.mne.fig_selection is not None:
            buttons = self.mne.fig_selection.mne.radio_ax.buttons
            labels = [label.get_text() for label in buttons.labels]
            current_label = buttons.value_selected
            current_idx = labels.index(current_label)
            selections_dict = self.mne.ch_selections
            # position relative to the end of the list of radio buttons
            penult = current_idx < (len(labels) - 1)
            pre_penult = current_idx < (len(labels) - 2)
            has_custom = selections_dict.get("Custom", None) is not None
            def_custom = len(selections_dict.get("Custom", list()))
            up_ok = key == "up" and current_idx > 0
            # moving down onto the last button is skipped when a "Custom"
            # entry exists but is empty
            down_ok = key == "down" and (
                pre_penult
                or (penult and not has_custom)
                or (penult and has_custom and def_custom)
            )
            if up_ok or down_ok:
                buttons.set_active(current_idx + direction)
        # normal case
        else:
            # scroll by one full page of channels, clipped to valid range
            ceiling = len(self.mne.ch_order) - n_channels
            ch_start = self.mne.ch_start + direction * n_channels
            self.mne.ch_start = np.clip(ch_start, 0, ceiling)
            self._update_picks()
            self._update_vscroll()
            self._redraw()
    # scroll left/right
    elif key in ("right", "left", "shift+right", "shift+left"):
        old_t_start = self.mne.t_start
        direction = 1 if key.endswith("right") else -1
        # step is duration / denom; with shift the step is a full duration
        if self.mne.is_epochs:
            denom = 1 if key.startswith("shift") else self.mne.n_epochs
        else:
            denom = 1 if key.startswith("shift") else 4
        t_max = last_time - self.mne.duration
        t_start = self.mne.t_start + direction * self.mne.duration / denom
        self.mne.t_start = np.clip(t_start, self.mne.first_time, t_max)
        if self.mne.t_start != old_t_start:
            self._update_hscroll()
            self._redraw(annotations=True)
    # scale traces
    elif key in ("=", "+", "-"):
        scaler = 1 / 1.1 if key == "-" else 1.1
        self.mne.scale_factor *= scaler
        self._redraw(update_data=False)
    # change number of visible channels
    elif (
        key in ("pageup", "pagedown")
        and self.mne.fig_selection is None
        and not self.mne.butterfly
    ):
        new_n_ch = n_channels + (1 if key == "pageup" else -1)
        self.mne.n_channels = np.clip(new_n_ch, 1, len(self.mne.ch_order))
        # add new chs from above if we're at the bottom of the scrollbar
        ch_end = self.mne.ch_start + self.mne.n_channels
        if ch_end > len(self.mne.ch_order) and self.mne.ch_start > 0:
            self.mne.ch_start -= 1
            self._update_vscroll()
        # redraw only if changed
        if self.mne.n_channels != n_channels:
            self._update_picks()
            self._update_trace_offsets()
            self._redraw(annotations=True)
    # change duration
    elif key in ("home", "end"):
        old_dur = self.mne.duration
        dur_delta = 1 if key == "end" else -1
        if self.mne.is_epochs:
            # prevent from showing zero epochs, or more epochs than we have
            self.mne.n_epochs = np.clip(
                self.mne.n_epochs + dur_delta, 1, len(self.mne.inst)
            )
            # use the length of one epoch as duration change
            min_dur = len(self.mne.inst.times) / self.mne.info["sfreq"]
            new_dur = self.mne.duration + dur_delta * min_dur
        else:
            # never show fewer than 3 samples
            min_dur = 3 * np.diff(self.mne.inst.times[:2])[0]
            # use multiplicative dur_delta
            dur_delta = 5 / 4 if dur_delta > 0 else 4 / 5
            new_dur = self.mne.duration * dur_delta
        self.mne.duration = np.clip(new_dur, min_dur, last_time)
        if self.mne.duration != old_dur:
            # keep the visible window from running past the end of the data
            if self.mne.t_start + self.mne.duration > last_time:
                self.mne.t_start = last_time - self.mne.duration
            self._update_hscroll()
            self._redraw(annotations=True)
    elif key == "?":  # help window
        self._toggle_help_fig(event)
    elif key == "a":  # annotation mode
        self._toggle_annotation_fig()
    elif key == "b" and self.mne.instance_type != "ica":  # butterfly mode
        self._toggle_butterfly()
    elif key == "d":  # DC shift
        self.mne.remove_dc = not self.mne.remove_dc
        self._redraw()
    elif key == "h":  # histogram
        self._toggle_epoch_histogram()
    elif key == "j" and len(self.mne.projs):  # SSP window
        self._toggle_proj_fig()
    elif key == "J" and len(self.mne.projs):
        self._toggle_proj_checkbox(event, toggle_all=True)
    elif key == "p":  # toggle draggable annotations
        self._toggle_draggable_annotations(event)
        if self.mne.fig_annotation is not None:
            # update the dialog's checkbox without firing its callback
            checkbox = self.mne.fig_annotation.mne.drag_checkbox
            with _events_off(checkbox):
                checkbox.set_active(0)
    elif key == "s":  # scalebars
        self._toggle_scalebars(event)
    elif key == "w":  # toggle noise cov whitening
        self._toggle_whitening()
    elif key == "z":  # zen mode: hide scrollbars and buttons
        self._toggle_scrollbars()
        self._redraw(update_data=False)
    elif key == "t":
        self._toggle_time_format()
    else:  # check for close key / fullscreen toggle
        super()._keypress(event)
|
||
|
||
def _buttonpress(self, event):
    """Handle mouse clicks.

    Left clicks toggle bad channels/epochs, place the vertical marker line,
    move the scrollbars, or open the projector/help dialogs, depending on
    the axes that was clicked. Right clicks delete an annotation (while the
    annotation dialog is open) or hide the vertical marker line.
    """
    from matplotlib.collections import PolyCollection

    from ..annotations import _sync_onset

    butterfly = self.mne.butterfly
    annotating = self.mne.fig_annotation is not None
    ax_main = self.mne.ax_main
    inst = self.mne.inst
    # ignore middle clicks, scroll wheel events, and clicks outside axes
    if event.button not in (1, 3) or event.inaxes is None:
        return
    elif event.button == 1:  # left-click (primary)
        # click in main axes
        if event.inaxes == ax_main and not annotating:
            if self.mne.instance_type == "epochs" or not butterfly:
                for line in self.mne.traces + self.mne.epoch_traces:
                    if line.contains(event)[0]:
                        if self.mne.instance_type == "epochs":
                            self._toggle_bad_epoch(event)
                        else:
                            idx = self.mne.traces.index(line)
                            self._toggle_bad_channel(idx)
                        # first trace under the cursor wins
                        return
            self._show_vline(event.xdata)  # butterfly / not on data trace
            self._redraw(update_data=False, annotations=False)
            return
        # click in vertical scrollbar
        elif event.inaxes == self.mne.ax_vscroll:
            if self.mne.fig_selection is not None:
                self._change_selection_vscroll(event)
            elif self._check_update_vscroll_clicked(event):
                self._redraw()
        # click in horizontal scrollbar
        elif event.inaxes == self.mne.ax_hscroll:
            if self._check_update_hscroll_clicked(event):
                self._redraw(annotations=True)
        # click on proj button
        elif event.inaxes == self.mne.ax_proj:
            self._toggle_proj_fig(event)
        # click on help button
        elif event.inaxes == self.mne.ax_help:
            self._toggle_help_fig(event)
    else:  # right-click (secondary)
        if annotating:
            # annotation spans are the PolyCollections in the main axes
            spans = [
                span
                for span in ax_main.collections
                if isinstance(span, PolyCollection)
            ]
            if any(span.contains(event)[0] for span in spans):
                xdata = event.xdata - self.mne.first_time
                start = _sync_onset(inst, inst.annotations.onset)
                end = start + inst.annotations.duration
                is_onscreen = self.mne.onscreen_annotations  # boolean array
                was_clicked = (xdata > start) & (xdata < end) & is_onscreen
                # determine which annotation label is "selected"
                buttons = self.mne.fig_annotation.mne.radio_ax.buttons
                current_label = buttons.value_selected
                is_active_label = inst.annotations.description == current_label
                # use z-order as tiebreaker (or if click wasn't on an active span)
                # (ax_main.collections only includes *visible* annots, so we offset)
                visible_zorders = [span.zorder for span in spans]
                zorders = np.zeros_like(is_onscreen).astype(int)
                offset = np.where(is_onscreen)[0][0]
                zorders[offset : (offset + len(visible_zorders))] = visible_zorders
                # among overlapping clicked spans, prefer removing spans whose label
                # is the active label; then fall back to zorder as deciding factor
                active_clicked = was_clicked & is_active_label
                mask = active_clicked if any(active_clicked) else was_clicked
                highest = zorders == zorders[mask].max()
                idx = np.where(highest)[0]
                inst.annotations.delete(idx)
            # refresh the annotation display whether or not a span was hit
            self._remove_annotation_hover_line()
            self._draw_annotations()
            self.canvas.draw_idle()
        elif event.inaxes == ax_main:
            self._toggle_vline(False)
|
||
|
||
def _pick(self, event):
    """React to pick events on channel-name text.

    A left click on a channel name toggles that channel's bad status; a
    right click opens the channel context figure. Nothing happens in
    butterfly mode or when the picked artist is not text.
    """
    from matplotlib.text import Text

    if self.mne.butterfly:
        return
    if not isinstance(event.artist, Text):
        return
    # clicked on channel name: find its position among the visible picks
    clicked_name = event.artist.get_text()
    visible_names = self.mne.ch_names[self.mne.picks].tolist()
    ind = visible_names.index(clicked_name)
    button = event.mouseevent.button
    if button == 1:  # left click
        self._toggle_bad_channel(ind)
    elif button == 3:  # right click
        self._create_ch_context_fig(ind)
|
||
|
||
def _create_ch_context_fig(self, idx):
    """Create the channel context figure via the base class and show it."""
    plt_show(fig=super()._create_ch_context_fig(idx))
|
||
|
||
def _new_child_figure(self, fig_name, *, layout=None, **kwargs):
    """Create a toolbar-less dialog figure that is a child of this figure.

    The new figure gets the default callbacks attached and is appended to
    ``self.mne.child_figs``. When ``fig_name`` is a string, the figure is
    additionally stored on ``self.mne`` under that attribute name.

    Parameters
    ----------
    fig_name : str | None
        Attribute name on ``self.mne`` under which to store the figure.
    layout : object
        Passed through to the figure factory.
    **kwargs
        Passed through to the figure factory.

    Returns
    -------
    fig
        The newly created child figure.
    """
    fixed_options = dict(
        toolbar=False, parent_fig=self, fig_name=fig_name, layout=layout
    )
    child = _figure(**fixed_options, **kwargs)
    child._add_default_callbacks()
    self.mne.child_figs.append(child)
    if isinstance(fig_name, str):
        setattr(self.mne, fig_name, child)
    return child
|
||
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
# HELP DIALOG
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
|
||
def _create_help_fig(self):
    """Build the help dialog: a key column and a description column.

    Entries of ``self._get_help_text()`` whose value is ``None`` are left
    out. Keys starting with ``"_"`` are section headers: the underscore is
    dropped and a blank line is inserted before them.
    """
    entries = {
        name: desc
        for name, desc in self._get_help_text().items()
        if desc is not None
    }
    key_chunks = []
    val_chunks = []
    for name, desc in entries.items():
        is_header = name.startswith("_")
        gap = "\n" if is_header else ""
        shown = name[1:] if is_header else name
        # one line break per line of the description keeps columns aligned
        padding = "\n" * (desc.count("\n") + 1)
        key_chunks.append(f"{gap}{shown} {padding}")
        val_chunks.append(f"{gap}{desc}\n")
    keys = "".join(key_chunks)
    vals = "".join(val_chunks)
    # figure size is estimated from line count and longest key/value line
    n_lines = keys.count("\n") + 1
    longest_key = max(len(name) for name in entries)
    longest_val = max(
        len(line) for desc in entries.values() for line in desc.split("\n")
    )
    width = (longest_key + longest_val) / 12
    height = n_lines / 5
    # create figure and a single borderless axes that fills it
    fig = self._new_child_figure(
        figsize=(width, height), fig_name="fig_help", window_title="Help"
    )
    ax = fig.add_axes((0.01, 0.01, 0.98, 0.98))
    ax.set_axis_off()
    shared = dict(va="top", linespacing=1.5, usetex=False)
    ax.text(0.42, 1, keys, ma="right", ha="right", **shared)
    ax.text(0.42, 1, vals, ma="left", ha="left", **shared)
|
||
|
||
def _toggle_help_fig(self, event):
    """Open the help dialog if there is none, otherwise close it.

    ``event`` is accepted for callback compatibility and is not used.
    """
    if self.mne.fig_help is not None:
        plt.close(self.mne.fig_help)
        return
    self._create_help_fig()
    plt_show(fig=self.mne.fig_help)
|
||
|
||
def _get_help_text(self):
    """Return the ordered key → description mapping for the help dialog.

    Keys starting with ``"_"`` are section headers. Entries that do not
    apply to the current instance type, platform, or data are given the
    value ``None``; the caller is expected to drop those.

    Returns
    -------
    help_text : OrderedDict
        Mapping from key/mouse-action names to descriptions (or ``None``).
    """
    kind = self.mne.instance_type
    is_raw = kind == "raw"
    is_epo = kind == "epochs"
    is_ica = kind == "ica"
    has_proj = len(self.mne.projs) > 0

    def when(condition, description):
        """Return ``description`` if ``condition`` holds, else ``None``."""
        return description if condition else None

    # key names depend on the platform
    is_mac = platform.system() == "Darwin"
    if is_mac:
        dur_keys = ("fn + ←", "fn + →")
        ch_keys = ("fn + ↑", "fn + ↓")
    else:
        dur_keys = ("Home", "End")
        ch_keys = ("Page up", "Page down")
    # descriptions depend on the instance type
    ch_cmp = "component" if is_ica else "channel"
    ch_epo = "epoch" if is_epo else "channel"
    ica_bad = "Mark/unmark component for exclusion"
    if self.mne.is_epochs:
        dur_vals = [f"Show {n} epochs" for n in ("fewer", "more")]
    else:
        dur_vals = [f"Show {d} time window" for d in ("shorter", "longer")]
    ch_vals = [
        f"{inc_dec} number of visible {ch_cmp}s"
        for inc_dec in ("Increase", "Decrease")
    ]
    lclick_data = ica_bad if is_ica else f"Mark/unmark bad {ch_epo}"
    lclick_name = ica_bad if is_ica else "Mark/unmark bad channel"
    rclick_name = {
        "ica": "Show diagnostics for component",
        "epochs": "Show imageplot for channel",
        "raw": "Show channel location",
    }[kind]
    # TODO not yet implemented: spectrum plot for a dragged time span
    ldrag = when(is_raw, "add annotation (in annotation mode)")
    noise_cov = when(self.mne.noise_cov is not None, "Toggle signal whitening")
    scrl = "1 epoch" if self.mne.is_epochs else "¼ window"

    # below, value " " is a hack to make value.split("\n") have length 1
    navigation = [
        ("_NAVIGATION", " "),
        ("→", f"Scroll {scrl} right (scroll full window with Shift + →)"),
        ("←", f"Scroll {scrl} left (scroll full window with Shift + ←)"),
        (dur_keys[0], dur_vals[0]),
        (dur_keys[1], dur_vals[1]),
        ("↑", f"Scroll up ({ch_cmp}s)"),
        ("↓", f"Scroll down ({ch_cmp}s)"),
        (ch_keys[0], ch_vals[0]),
        (ch_keys[1], ch_vals[1]),
    ]
    signal = [
        ("_SIGNAL TRANSFORMATIONS", " "),
        ("+ or =", "Increase signal scaling"),
        ("-", "Decrease signal scaling"),
        ("b", when(not is_ica, "Toggle butterfly mode")),
        ("d", when(is_raw, "Toggle DC removal")),
        ("w", noise_cov),
    ]
    interface = [
        ("_USER INTERFACE", " "),
        ("a", when(is_raw, "Toggle annotation mode")),
        ("h", when(is_epo, "Toggle peak-to-peak histogram")),
        ("j", when(has_proj, "Toggle SSP projector window")),
        ("shift+j", "Toggle all SSPs"),
        ("p", when(is_raw, "Toggle draggable annotations")),
        ("s", when(not is_ica, "Toggle scalebars")),
        ("z", "Toggle scrollbars"),
        ("t", when(not is_epo, "Toggle time format")),
        ("F11", when(not is_mac, "Toggle fullscreen")),
        ("?", "Open this help window"),
        ("esc", "Close focused figure or dialog window"),
    ]
    mouse = [
        ("_MOUSE INTERACTION", " "),
        (f"Left-click {ch_cmp} name", lclick_name),
        (f"Left-click {ch_cmp} data", lclick_data),
        ("Left-click-and-drag on plot", ldrag),
        ("Left-click on plot background", "Place vertical guide"),
        ("Right-click on plot background", "Clear vertical guide"),
        ("Right-click on channel name", rclick_name),
    ]
    return OrderedDict(navigation + signal + interface + mouse)
|
||
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
# ANNOTATIONS
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
|
||
def _create_annotation_fig(self):
    """Build the annotation dialog and enable span selection on the plot.

    The dialog stacks, around a central radio-button axes: a show/hide
    checkbox column on the right, instructions on top, and below a text
    entry line, an "Add new label" button and a "Draggable edges?"
    checkbox. Afterwards a horizontal ``SpanSelector`` is installed on the
    main axes and a hover callback is connected to this figure's canvas.
    """
    from matplotlib.widgets import Button, CheckButtons, SpanSelector
    from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
    from mpl_toolkits.axes_grid1.axes_size import Fixed

    # make figure; its height grows with the number of distinct labels
    n_labels = len(set(self.mne.inst.annotations.description))
    radio_button_h = self._compute_annotation_figsize(n_labels)
    fig = self._new_child_figure(
        figsize=(ANNOTATION_FIG_W, ANNOTATION_FIG_MIN_H + radio_button_h),
        FigureClass=MNEAnnotationFigure,
        fig_name="fig_annotation",
        window_title="Annotations",
    )
    # make main axes, padded on all sides
    left = fig._inch_to_rel(ANNOTATION_FIG_PAD)
    bottom = fig._inch_to_rel(ANNOTATION_FIG_PAD, horiz=False)
    main_rect = (left, bottom, 1 - 2 * left, 1 - 2 * bottom)
    fig.mne.radio_ax = fig.add_axes(main_rect, frame_on=False, aspect="equal")
    div = make_axes_locatable(fig.mne.radio_ax)
    # append show/hide checkboxes at right
    fig.mne.show_hide_ax = div.append_axes(
        position="right",
        size=Fixed(ANNOTATION_FIG_CHECKBOX_COLUMN_W),
        pad=Fixed(ANNOTATION_FIG_PAD),
        aspect="equal",
        sharey=fig.mne.radio_ax,
    )
    # populate w/ radio buttons & labels
    self._update_annotation_fig()
    # append instructions at top
    instructions_ax = div.append_axes(
        position="top", size=Fixed(1), pad=Fixed(5 * ANNOTATION_FIG_PAD)
    )
    instruction_lines = [
        r"$\mathbf{Left‐click~&~drag~on~plot:}$ create/modify annotation",
        r"$\mathbf{Right‐click~on~plot~annotation:}$ delete annotation",
        r"$\mathbf{Type~in~annotation~window:}$ modify new label name",
        r"$\mathbf{Enter~(or~click~button):}$ add new label to list",
        r"$\mathbf{Esc:}$ exit annotation mode & close this window",
    ]
    # usetex=False forces use of the MPL mathtext parser
    instructions_ax.text(
        0,
        1,
        "\n".join(instruction_lines),
        va="top",
        ha="left",
        linespacing=1.7,
        usetex=False,
    )
    instructions_ax.set_axis_off()
    # append text entry axes at bottom
    text_entry_ax = div.append_axes(
        position="bottom",
        size=Fixed(3 * ANNOTATION_FIG_PAD),
        pad=Fixed(ANNOTATION_FIG_PAD),
    )
    text_entry_ax.text(
        0.4, 0.5, "New label:", va="center", ha="right", weight="bold"
    )
    fig.label = text_entry_ax.text(0.5, 0.5, "BAD_", va="center", ha="left")
    text_entry_ax.set_axis_off()
    # append button at bottom
    button_ax = div.append_axes(
        position="bottom",
        size=Fixed(3 * ANNOTATION_FIG_PAD),
        pad=Fixed(ANNOTATION_FIG_PAD),
    )
    fig.button = Button(button_ax, "Add new label")
    fig.button.on_clicked(self._add_annotation_label)
    plt_show(fig=fig)
    # add "draggable" checkbox
    drag_ax_height = 3 * ANNOTATION_FIG_PAD
    drag_ax = div.append_axes(
        "bottom", size=Fixed(drag_ax_height), pad=Fixed(ANNOTATION_FIG_PAD)
    )
    checkbox = CheckButtons(
        drag_ax,
        labels=("Draggable edges?",),
        actives=(self.mne.draggable_annotations,),
        **_get_check_kwargs(),
    )
    checkbox.on_clicked(self._toggle_draggable_annotations)
    fig.mne.drag_checkbox = checkbox
    # reposition & resize axes
    width_in, _ = fig.get_size_inches()
    width_ax = fig._inch_to_rel(
        width_in - ANNOTATION_FIG_CHECKBOX_COLUMN_W - 3 * ANNOTATION_FIG_PAD
    )
    aspect = width_ax / fig._inch_to_rel(drag_ax_height)
    drag_ax.set(xlim=(0, aspect), ylim=(0, 1))
    drag_ax.set_axis_off()
    if _OLD_BUTTONS:
        # older widgets expose their artists; restyle them by hand
        rect = checkbox.rectangles[0]
        _pad, _size = (0.2, 0.6)
        rect.set_bounds(_pad, _pad, _size, _size)
        lines = checkbox.lines[0]
        for line, direction in zip(lines, (1, -1)):
            line.set_xdata((_pad, _pad + _size)[::direction])
            line.set_ydata((_pad, _pad + _size))
        text = checkbox.labels[0]
        text.set(position=(3 * _pad + _size, 0.45), va="center")
        for artist in lines + (rect, text):
            artist.set_transform(drag_ax.transData)
    # setup interactivity in plot window
    if fig.mne.radio_ax.buttons is None:
        col = "#ff0000"
    else:
        first_label = self._get_annotation_labels()[0]
        col = self.mne.annotation_segment_colors[first_label]

    self.mne.ax_main.selector = SpanSelector(
        self.mne.ax_main,
        self._select_annotation_span,
        "horizontal",
        minspan=0.1,
        useblit=True,
        button=1,
        props=dict(alpha=0.5, facecolor=col),
    )
    hover_cid = self.canvas.mpl_connect("motion_notify_event", self._hover)
    self.mne._callback_ids["motion_notify_event"] = hover_cid
|
||
|
||
def _toggle_visible_annotations(self, event):
    """Sync per-label annotation visibility with the show/hide checkboxes.

    ``self.mne.visible_annotations`` is rebuilt from the checkbox labels
    and their checked state, then the annotations are redrawn without
    reloading data. ``event`` is not used.
    """
    boxes = self.mne.show_hide_annotation_checkboxes
    names = [text.get_text() for text in boxes.labels]
    states = boxes.get_status()
    self.mne.visible_annotations = {
        name: state for name, state in zip(names, states)
    }
    self._redraw(update_data=False, annotations=True)
|
||
|
||
def _toggle_draggable_annotations(self, event):
    """Flip the flag that controls whether annotation edges are draggable.

    ``event`` is accepted for callback compatibility and is not used.
    """
    was_draggable = self.mne.draggable_annotations
    self.mne.draggable_annotations = not was_draggable
|
||
|
||
def _update_annotation_fig(self, *, draw=True):
    """Rebuild the label radio buttons and the show/hide checkboxes.

    The dialog is resized to fit the current number of labels, both widget
    axes are cleared and repopulated, and the resulting visibility mapping
    and checkbox widget are stored on ``self.mne``.

    Parameters
    ----------
    draw : bool
        If True (default), request an idle redraw of the dialog canvas
        after rebuilding.
    """
    from matplotlib.colors import to_rgba
    from matplotlib.widgets import CheckButtons, RadioButtons

    # define shorthand variables
    dialog = self.mne.fig_annotation
    radio_ax = dialog.mne.radio_ax
    labels = self._get_annotation_labels()
    colors = self.mne.annotation_segment_colors
    # compute new figsize
    radio_button_h = self._compute_annotation_figsize(len(labels))
    dialog.set_size_inches(
        ANNOTATION_FIG_W, ANNOTATION_FIG_MIN_H + radio_button_h, forward=True
    )
    # populate center axes with labels & radio buttons
    radio_ax.clear()
    if len(labels):
        title = "Existing labels:"
    else:
        title = "No existing labels"
    radio_ax.set_title(title, size=None, loc="left")
    if not len(labels):
        radio_ax.buttons = None
    elif _OLD_BUTTONS:
        radio_ax.buttons = RadioButtons(radio_ax, labels)
        radius = 0.15
        for circle, label in zip(radio_ax.buttons.circles, radio_ax.buttons.labels):
            circle.set_transform(radio_ax.transData)
            # x position of 0.1 in axes coordinates, expressed in data coords
            center = radio_ax.transData.inverted().transform(
                radio_ax.transAxes.transform((0.1, 0))
            )
            circle.set_center((center[0], circle.center[1]))
            circle.set_edgecolor(colors[label.get_text()])
            circle.set_linewidth(4)
            circle.set_radius(radius / len(labels))
    else:
        edgecolors = [colors[label] for label in labels]
        # same hues as the edges, but half-transparent
        facecolors = [to_rgba(col)[:3] + (0.5,) for col in edgecolors]
        radio_ax.buttons = RadioButtons(
            radio_ax,
            labels,
            radio_props=dict(
                s=144,
                linewidth=4,
                edgecolor=edgecolors,
                facecolor=facecolors,
            ),
        )
    # adjust xlim to keep equal aspect & full width (keep circles round)
    radio_width = (
        ANNOTATION_FIG_W - ANNOTATION_FIG_CHECKBOX_COLUMN_W - 3 * ANNOTATION_FIG_PAD
    )
    radio_ax.set_xlim((0, radio_width / radio_button_h))
    # style the selected button
    if len(labels):
        dialog._set_active_button(0, draw=False)
    # add event listeners
    if radio_ax.buttons is not None:
        if _OLD_BUTTONS:
            radio_ax.buttons.disconnect_events()  # clear MPL default listeners
        radio_ax.buttons.on_clicked(dialog._radiopress)
        if _OLD_BUTTONS:
            radio_ax.buttons.connect_event(
                "button_press_event", dialog._click_override
            )
    radio_ax.set_axis_off()

    # now do the show/hide checkboxes
    show_hide_ax = dialog.mne.show_hide_ax
    show_hide_ax.clear()
    show_hide_ax.set_axis_on()
    aspect = ANNOTATION_FIG_CHECKBOX_COLUMN_W / radio_button_h
    show_hide_ax.set(xlim=(0, aspect), ylim=(0, 1))
    # labels without a stored state start unchecked; stored states win
    check_values = dict.fromkeys(labels, False)
    check_values.update(self.mne.visible_annotations)  # existing checks
    actives = [check_values[label] for label in labels]
    # regenerate checkboxes
    checkboxes = CheckButtons(
        ax=dialog.mne.show_hide_ax,
        labels=labels,
        actives=actives,
        **_get_check_kwargs(),
    )
    checkboxes.on_clicked(self._toggle_visible_annotations)
    # add title, hide labels
    show_hide_title = "show/\nhide " if len(labels) else ""
    show_hide_ax.set_title(show_hide_title, size=None, loc="right")
    for box_label in checkboxes.labels:
        box_label.set_visible(False)
    show_hide_ax.set_axis_off()
    # fix aspect and right-align
    if _OLD_BUTTONS:
        if len(labels) == 1:
            bounds = (0.05, 0.375, 0.25, 0.25)  # undo MPL special case
            checkboxes.rectangles[0].set_bounds(bounds)
            for line, step in zip(checkboxes.lines[0], (1, -1)):
                line.set_xdata((bounds[0], bounds[0] + bounds[2]))
                line.set_ydata((bounds[1], bounds[1] + bounds[3])[::step])
        for rect in checkboxes.rectangles:
            rect.set_transform(show_hide_ax.transData)
            bbox = rect.get_bbox()
            rect.set_bounds((aspect, bbox.ymin, -bbox.width, bbox.height))
            rect.set_clip_on(False)
        for line in np.array(checkboxes.lines).ravel():
            line.set_transform(show_hide_ax.transData)
            line.set_xdata(aspect + 0.05 - np.array(line.get_xdata()))
    # store state
    self.mne.visible_annotations = check_values
    self.mne.show_hide_annotation_checkboxes = checkboxes
    if draw:
        dialog.canvas.draw_idle()
|
||
|
||
def _toggle_annotation_fig(self):
    """Show/hide the annotation dialog window.

    An existing dialog is closed. If there is none, a new one is created,
    except for epochs instances, for which nothing happens.

    Notes
    -----
    The no-dialog epochs case must not reach ``plt.close``: called with
    ``None``, pyplot closes the *current* figure rather than doing nothing.
    """
    fig = self.mne.fig_annotation
    if fig is not None:
        plt.close(fig)
    elif not self.mne.is_epochs:
        self._create_annotation_fig()
|
||
|
||
def _compute_annotation_figsize(self, n_labels):
    """Return the height of the label-list axes for ``n_labels`` labels.

    Each label gets 0.7 of height; the result is never smaller than
    ``ANNOTATION_FIG_PAD``.

    Layout implemented by ``self._create_annotation_fig()``, top to bottom:

    Fixed part of height:
        0.1  top margin
        1.0  instructions
        0.5  padding below instructions
        ---  (variable-height axis for label list, returned by this method)
        0.1  padding above text entry
        0.3  text entry
        0.1  padding above button
        0.3  button
        0.1  padding above checkbox
        0.3  checkbox
        0.1  bottom margin
        ------------------------------------------
        2.9  total fixed height
    """
    wanted = 0.7 * n_labels
    if wanted > ANNOTATION_FIG_PAD:
        return wanted
    return ANNOTATION_FIG_PAD
|
||
|
||
def _add_annotation_label(self, event):
    """Register the text-entry content as a new annotation label.

    If the label was already added, only the list title is changed to
    point out the duplicate. Otherwise the label is stored, colors and the
    dialog are refreshed, the new label's radio button is activated, and
    the text entry is reset to ``"BAD_"``. ``event`` is not used.
    """
    dialog = self.mne.fig_annotation
    text = dialog.label.get_text()
    # If it exists, change this title. If it doesn't, the title will
    # be set in _update_annotation_fig()
    if text in self.mne.new_annotation_labels:
        dialog.mne.radio_ax.set_title(
            f"Existing labels: (duplicate label: {text!r})",
            size=None,
            loc="left",
        )
        dialog.canvas.draw()
        return
    self.mne.new_annotation_labels.append(text)
    self._setup_annotation_colors()
    self._update_annotation_fig(draw=False)
    # automatically activate new label's radio button
    button_names = [
        label.get_text() for label in dialog.mne.radio_ax.buttons.labels
    ]
    dialog._set_active_button(button_names.index(text), draw=False)
    # simulate a click on the radiobutton → update the span selector color
    dialog._radiopress(event=None, draw=False)
    # reset the text entry box's text
    dialog.label.set_text("BAD_")
    dialog.canvas.draw()
|
||
|
||
def _select_annotation_span(self, vmin, vmax):
    """Add an annotation for the span selected on the main axes.

    Parameters
    ----------
    vmin, vmax : float
        Start and end of the selected span, as reported by the selector.

    Notes
    -----
    When no label exists (or none is selected), a warning is logged and
    nothing is added. If the active label is currently hidden, its
    show/hide checkbox is switched on.
    """
    from ..annotations import _sync_onset

    onset = _sync_onset(self.mne.inst, vmin, True) - self.mne.first_time
    duration = vmax - vmin
    buttons = self.mne.fig_annotation.mne.radio_ax.buttons
    if buttons is None or buttons.value_selected is None:
        logger.warning(
            "No annotation-label exists! "
            "Add one by typing the name and clicking "
            'on "Add new label" in the annotation-dialog.'
        )
        return
    names = [label.get_text() for label in buttons.labels]
    active_idx = names.index(buttons.value_selected)
    _merge_annotations(
        onset, onset + duration, names[active_idx], self.mne.inst.annotations
    )
    # if adding a span with an annotation label that is hidden, show it
    if not self.mne.visible_annotations[buttons.value_selected]:
        self.mne.show_hide_annotation_checkboxes.set_active(active_idx)
    self._redraw(update_data=False, annotations=True)
|
||
|
||
def _remove_annotation_hover_line(self):
    """Drop the annotation hover line (if any) and re-enable the selector.

    When a hover line exists it is removed from the plot, the stored
    reference is cleared, the main axes' span selector is reactivated and
    the canvas is redrawn.
    """
    hover_line = self.mne.annotation_hover_line
    if hover_line is None:
        return
    hover_line.remove()
    self.mne.annotation_hover_line = None
    self.mne.ax_main.selector.active = True
    self.canvas.draw()
|
||
|
||
def _modify_annotation(self, old_x, new_x):
    """Move one edge of an existing annotation from ``old_x`` to ``new_x``.

    Parameters
    ----------
    old_x : float
        Current position of the edge; must equal an entry of
        ``self.mne.annotation_segments``, otherwise nothing happens.
    new_x : float
        New position of that edge.

    Notes
    -----
    If the edge is dragged past the opposite edge, onset and duration are
    swapped around so that the duration stays non-negative.
    """
    from ..annotations import _sync_onset

    hits = np.array(np.where(self.mne.annotation_segments == old_x))
    if hits.shape[1] == 0:
        return
    raw = self.mne.inst
    annotations = raw.annotations
    first_time = self.mne.first_time
    # first match: which segment (row) and which edge (column 0 = start)
    seg_row, seg_col = hits[0][0], hits[1][0]
    seg_start = self.mne.annotation_segments[seg_row][0]
    onset = _sync_onset(raw, seg_start, True)
    ann_idx = np.where(annotations.onset == onset - first_time)[0]
    if seg_col == 0:  # start of annotation
        onset = _sync_onset(raw, new_x, True) - first_time
        duration = annotations.duration[ann_idx] + old_x - new_x
    else:  # end of annotation
        onset = annotations.onset[ann_idx]
        duration = _sync_onset(raw, new_x, True) - onset - first_time
    if duration < 0:
        # edge crossed its counterpart: flip so duration is positive
        onset += duration
        duration *= -1.0
    _merge_annotations(
        onset,
        onset + duration,
        annotations.description[ann_idx],
        annotations,
        ann_idx,
    )
    self._draw_annotations()
    self._remove_annotation_hover_line()
    self.canvas.draw_idle()
|
||
|
||
def _clear_annotations(self):
    """Clear all annotations from the figure.

    Every artist stored in ``self.mne.annotations``,
    ``self.mne.hscroll_annotations`` and ``self.mne.annotation_texts`` is
    removed from the plot, and the three lists are emptied in place (the
    list objects themselves are kept).
    """
    for attr in ("annotations", "hscroll_annotations", "annotation_texts"):
        artists = getattr(self.mne, attr)
        for artist in artists:
            artist.remove()
        # one in-place clear instead of a linear list.remove() per artist
        artists.clear()
|
||
|
||
def _draw_annotations(self):
    """Draw (or redraw) the annotation spans and their text labels.

    Existing annotation artists are cleared first. For every annotation
    whose label is set visible, a span is drawn on the horizontal
    scrollbar; if the annotation overlaps the time range currently shown,
    the overlapping part is also drawn on the main axes together with a
    text label. ``self.mne.onscreen_annotations`` ends up as a boolean
    array flagging the annotations drawn on the main axes.
    """
    self._clear_annotations()
    self._update_annotation_segments()
    segments = self.mne.annotation_segments
    times = self.mne.times
    ax = self.mne.ax_main
    ylim = ax.get_ylim()
    onscreen = np.zeros(len(segments), dtype=bool)
    for idx, (start, end) in enumerate(segments):
        descr = self.mne.inst.annotations.description[idx]
        segment_color = self.mne.annotation_segment_colors[descr]
        span_style = dict(
            color=segment_color,
            alpha=0.3,
            zorder=self.mne.zorder["ann"] + idx,
        )
        if not self.mne.visible_annotations[descr]:
            continue
        # draw all segments on ax_hscroll
        bar = self.mne.ax_hscroll.fill_betweenx((0, 1), start, end, **span_style)
        self.mne.hscroll_annotations.append(bar)
        # draw only visible segments on ax_main
        visible_segment = np.clip([start, end], times[0], times[-1])
        if not np.diff(visible_segment) > 0:
            continue
        span = ax.fill_betweenx(ylim, *visible_segment, **span_style)
        self.mne.annotations.append(span)
        onscreen[idx] = True
        label = ax.annotate(
            descr,
            (visible_segment.mean(), ylim[1]),
            xytext=(0, 9),
            textcoords="offset points",
            ha="center",
            va="baseline",
            color=segment_color,
        )
        self.mne.annotation_texts.append(label)
    self.mne.onscreen_annotations = onscreen
|
||
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
# CHANNEL SELECTION GUI
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
|
||
def _create_selection_fig(self):
    """Build the channel-selection dialog.

    The dialog consists of a sensor plot on top (with lasso selection), a
    radio-button list of the channel selections (plus an initially empty
    "Custom" entry for the lasso), and usage instructions at the bottom.
    """
    from matplotlib.colors import to_rgb
    from matplotlib.widgets import RadioButtons

    # make figure
    fig = self._new_child_figure(
        figsize=(3, 7),
        FigureClass=MNESelectionFigure,
        fig_name="fig_selection",
        window_title="Channel selection",
    )
    grid = fig.add_gridspec(15, 1)
    # add sensor plot at top
    fig.mne.sensor_ax = fig.add_subplot(grid[:5])
    plot_sensors(
        self.mne.info,
        kind="select",
        ch_type="all",
        title="",
        axes=fig.mne.sensor_ax,
        ch_groups=self.mne.group_by,
        show=False,
    )
    fig.subplots_adjust(bottom=0.01, top=0.99, left=0.01, right=0.99)
    # style the sensors so the selection is easier to distinguish
    fig.lasso.linewidth_selected = 2
    self._update_highlighted_sensors()
    # add radio button axes
    radio_ax = fig.add_subplot(grid[5:-3], frame_on=False, aspect="equal")
    fig.mne.radio_ax = radio_ax
    selections = self.mne.ch_selections
    selections["Custom"] = np.array([], dtype=int)  # for lasso
    names = list(selections)
    # make & style the radio buttons
    radio_ax.buttons = RadioButtons(
        radio_ax, names, activecolor=to_rgb(self.mne.fgcolor) + (0.5,)
    )
    fig.mne.old_selection = 0
    if _OLD_BUTTONS:
        for circle in radio_ax.buttons.circles:
            circle.set_radius(0.25 / len(names))
            circle.set_linewidth(2)
            circle.set_edgecolor(self.mne.fgcolor)
    fig._style_radio_buttons_butterfly()
    # add instructions at bottom
    instructions = (
        "To use a custom selection, first click-drag on the sensor plot "
        'to "lasso" the sensors you want to select, or hold Ctrl while '
        "clicking individual sensors. Holding Ctrl while click-dragging "
        "allows a lasso selection adding to (rather than replacing) the "
        "existing selection."
    )
    instructions_ax = fig.add_subplot(grid[-3:], frame_on=False)
    instructions_ax.text(
        0.04, 0.08, instructions, va="bottom", ha="left", ma="left", wrap=True
    )
    instructions_ax.set_axis_off()
    # add event listeners
    radio_ax.buttons.on_clicked(fig._radiopress)
    fig.lasso.callbacks.append(fig._set_custom_selection)
|
||
|
||
def _change_selection_vscroll(self, event):
    """Activate the selection group that was clicked on the scrollbar.

    The selection groups are walked in radio-button order while summing
    their channel counts; the first group whose cumulative count exceeds
    ``event.ydata`` is activated (with the button's own events silenced),
    and the selection figure's radio handler is then called with ``event``.
    """
    buttons = self.mne.fig_selection.mne.radio_ax.buttons
    names = [label.get_text() for label in buttons.labels]
    selections = self.mne.ch_selections
    upper_edge = 0
    for idx, name in enumerate(names):
        upper_edge += len(selections[name])
        if not event.ydata < upper_edge:
            continue
        with _events_off(buttons):
            buttons.set_active(idx)
        self.mne.fig_selection._radiopress(event)
        return
|
||
|
||
def _update_selection(self):
    """Apply the selection chosen in the selection dialog to the plot.

    Updates the picked channels, their count and the first visible channel
    row, remembers the index of the chosen radio button, refreshes the
    sensor highlighting, and redraws (including annotations).
    """
    selections = self.mne.ch_selections
    buttons = self.mne.fig_selection.mne.radio_ax.buttons
    chosen = buttons.value_selected
    names = [text.get_text() for text in buttons.labels]
    self.mne.fig_selection.mne.old_selection = names.index(chosen)
    self.mne.picks = selections[chosen]
    self.mne.n_channels = len(self.mne.picks)
    self._update_highlighted_sensors()
    # if "Vertex" is defined, some channels appear twice, so if
    # "Vertex" is selected, ch_start should be the *first* match;
    # otherwise it should be the *last* match (since "Vertex" is
    # always the first selection group, if it exists).
    ch_order = np.concatenate(list(selections.values()))
    matches = np.where(ch_order == self.mne.picks[0])[0]
    if chosen == "Vertex":
        self.mne.ch_start = matches[0]
    else:
        self.mne.ch_start = matches[-1]
    self._update_trace_offsets()
    self._update_vscroll()
    self._redraw(annotations=True)
|
||
|
||
def _update_highlighted_sensors(self):
    """Highlight the currently picked channels on the sensor plot.

    The lasso's channel names are matched against the names of the picked
    channels, and the matching positions are handed to the lasso's
    ``select_many``.
    """
    lasso = self.mne.fig_selection.lasso
    picked_names = self.mne.ch_names[self.mne.picks]
    is_picked = np.isin(lasso.ch_names, picked_names)
    inds = is_picked.nonzero()[0]
    self.mne.fig_selection.lasso.select_many(inds)
|
||
|
||
def _update_bad_sensors(self, pick, mark_bad):
    """Recolor a sensor on the sensor plot after it was (un)marked as bad.

    Parameters
    ----------
    pick : int
        Channel index whose sensor should be recolored.
    mark_bad : bool
        Whether the channel is now marked bad; written (as float) into the
        first (red) column of the lasso's edge-color array.
    """
    # replicate plotting order from plot_sensors(), to get index right
    ch_indices = channel_indices_by_type(self.mne.info)
    sensor_picks = [
        ch_ix
        for this_type in _DATA_CH_TYPES_SPLIT
        if this_type in self.mne.ch_types
        for ch_ix in ch_indices[this_type]
    ]
    sensor_idx = np.isin(sensor_picks, pick).nonzero()[0]
    # change the sensor color
    lasso = self.mne.fig_selection.lasso
    lasso.ec[sensor_idx, 0] = float(mark_bad)  # change R of RGBA array
    lasso.collection.set_edgecolors(lasso.ec)
    self.mne.fig_selection.canvas.draw_idle()
|
||
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
# PROJECTORS & BAD CHANNELS
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
|
||
def _create_proj_fig(self):
    """Create the projectors dialog window.

    The dialog lists one checkbox per projector (labelled with its
    description, suffixed with " (already applied)" where
    ``self.mne.projs_active`` is set) plus a "Toggle all" button.
    Presses of "j"/"J" inside the dialog are forwarded to this figure's
    key handler.
    """
    from matplotlib.widgets import Button, CheckButtons

    projs = self.mne.projs
    labels = [p["desc"] for p in projs]
    for ix, active in enumerate(self.mne.projs_active):
        if active:
            labels[ix] += " (already applied)"
    # make figure (wide enough for the longest label)
    width = max([4.5, max([len(label) for label in labels]) / 8 + 0.5])
    height = (len(projs) + 1) / 6 + 1.5
    fig = self._new_child_figure(
        figsize=(width, height),
        fig_name="fig_proj",
        window_title="SSP projection vectors",
    )

    def _forward_proj_keys(ev):
        """Pass "j"/"J" key presses on to the parent figure."""
        # tuple membership: ``ev.key`` may be None, and a substring test
        # against "jJ" would raise TypeError for None and match ""/"jJ"
        if ev.key in ("j", "J"):
            self._keypress(ev)

    # pass through some proj fig keypresses to the parent
    fig.canvas.mpl_connect("key_press_event", _forward_proj_keys)
    # make axes
    offset = 1 / 6 / height
    position = (0, offset, 1, 0.8 - offset)
    ax = fig.add_axes(position, frame_on=False, aspect="equal")
    # make title
    first_line = (
        "Projectors already applied to the data are dimmed.\n"
        if any(self.mne.projs_active)
        else ""
    )
    second_line = 'Projectors marked with "X" are active on the plot.'
    ax.set_title(f"{first_line}{second_line}")
    # draw checkboxes
    checkboxes = CheckButtons(
        ax,
        labels=labels,
        actives=self.mne.projs_on,
        **_get_check_kwargs(labels=labels),
    )
    # gray-out already applied projectors
    if _OLD_BUTTONS:
        for label, rect, lines in zip(
            checkboxes.labels, checkboxes.rectangles, checkboxes.lines
        ):
            if label.get_text().endswith("(already applied)"):
                label.set_color("0.5")
                rect.set_edgecolor("0.7")
                for line in lines:
                    line.set_color("0.7")
            rect.set_linewidth(1)
    # add "toggle all" button
    ax_all = fig.add_axes((0.25, 0.01, 0.5, offset), frame_on=True)
    fig.mne.proj_all = Button(ax_all, "Toggle all")
    # add event listeners
    checkboxes.on_clicked(self._toggle_proj_checkbox)
    fig.mne.proj_all.on_clicked(
        partial(self._toggle_proj_checkbox, toggle_all=True)
    )
    # save params
    fig.mne.proj_checkboxes = checkboxes
    # show figure
    self.mne.fig_proj.canvas.draw()
    plt_show(fig=self.mne.fig_proj, warn=False)
|
||
|
||
def _toggle_proj_fig(self, event=None):
    """Open the projectors dialog if there is none, otherwise close it.

    ``event`` is accepted for callback compatibility and is not used.
    """
    existing = self.mne.fig_proj
    if existing is not None:
        plt.close(existing)
        return
    self._create_proj_fig()
|
||
|
||
def _toggle_proj_checkbox(self, event, toggle_all=False):
    """React to a click on a projector checkbox or on "Toggle all".

    Parameters
    ----------
    event : object
        Not used; accepted for callback compatibility.
    toggle_all : bool
        If True, switch every projector on, unless all are on already, in
        which case switch every projector off. If False, take the new
        state from the dialog's checkboxes.

    Notes
    -----
    While the dialog exists, projectors flagged in
    ``self.mne.projs_active`` are forced back on. Projector and plot are
    only updated when the resulting state differs from the current one.
    """
    on = self.mne.projs_on
    applied = self.mne.projs_active
    fig = self.mne.fig_proj
    if toggle_all:
        new_state = np.full_like(on, not all(on))
    else:
        new_state = np.array(fig.mne.proj_checkboxes.get_status())
    # update Xs when toggling all
    if fig is not None:
        boxes = fig.mne.proj_checkboxes
        if toggle_all:
            with _events_off(boxes):
                for ix in np.where(on != new_state)[0]:
                    boxes.set_active(ix)
        # don't allow disabling already-applied projs
        with _events_off(boxes):
            for ix in np.where(applied)[0]:
                if not new_state[ix]:
                    boxes.set_active(ix)
        new_state[applied] = True
    # update the data if necessary
    if np.array_equal(on, new_state):
        return
    self.mne.projs_on = new_state
    self._update_projector()
    self._redraw()
|
||
|
||
def _toggle_epoch_histogram(self):
    """Open or close the peak-to-peak histogram (epochs instances only).

    For any other instance type this does nothing.
    """
    if self.mne.instance_type != "epochs":
        return
    if self.mne.fig_histogram is None:
        self._create_epoch_histogram()
    else:
        plt.close(self.mne.fig_histogram)
|
||
|
||
def _toggle_bad_channel(self, idx):
    """Mark/unmark a channel as bad and update the dependent UI elements.

    Parameters
    ----------
    idx : int
        Index of the channel among the *visible* channels.

    Notes
    -----
    The parent class does the actual toggling; here the sensor plot (if a
    selection figure exists) and the vertical scrollbar patches are
    recolored, then the figure is redrawn.
    """
    color, pick, marked_bad = super()._toggle_bad_channel(idx)

    # update sensor color (if in selection mode)
    if self.mne.fig_selection is not None:
        self._update_bad_sensors(pick, marked_bad)
    # update vscroll color (a channel may occupy several rows)
    scrollbar_patches = self.mne.ax_vscroll.patches
    for row in (self.mne.ch_order == pick).nonzero()[0]:
        scrollbar_patches[row].set_color(color)
    # redraw
    self._redraw()
|
||
|
||
def _toggle_bad_epoch(self, event):
    """Mark/unmark the epoch at the event's x position as bad.

    The parent class does the toggling and reports the epoch index and its
    new color; that color is applied to the matching patch of the
    horizontal scrollbar before redrawing without reloading data.
    """
    toggled = super()._toggle_bad_epoch(event.xdata)
    epoch_ix, color = toggled
    epoch_patch = self.mne.ax_hscroll.patches[epoch_ix]
    epoch_patch.set_color(color)
    self._redraw(update_data=False)
|
||
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
# SCROLLBARS
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
|
||
def _update_zen_mode_offsets(self):
    """Store the gaps between the main axes and the scrollbar edges.

    Sets ``self.mne.fig_size_px`` to the current figure size,
    ``self.mne.zen_w`` to the distance between the right edges of the
    vertical scrollbar and the main axes, and ``self.mne.zen_h`` to the
    distance between the bottom edges of the main axes and the horizontal
    scrollbar.
    """
    self.mne.fig_size_px = self._get_size_px()
    vscroll_right = self.mne.ax_vscroll.get_position().xmax
    main_right = self.mne.ax_main.get_position().xmax
    self.mne.zen_w = vscroll_right - main_right
    main_bottom = self.mne.ax_main.get_position().ymin
    hscroll_bottom = self.mne.ax_hscroll.get_position().ymin
    self.mne.zen_h = main_bottom - hscroll_bottom
|
||
|
||
def _toggle_scrollbars(self):
|
||
"""Show or hide scrollbars (A.K.A. zen mode)."""
|
||
self._update_zen_mode_offsets()
|
||
# grow/shrink main axes to take up space from (or make room for)
|
||
# scrollbars. We can't use ax.set_position() because axes are
|
||
# locatable, so we use subplots_adjust
|
||
should_show = not self.mne.scrollbars_visible
|
||
margins = {
|
||
side: getattr(self.subplotpars, side)
|
||
for side in ("left", "bottom", "right", "top")
|
||
}
|
||
# if should_show, bottom margin moves up; right margin moves left
|
||
margins["bottom"] += (1 if should_show else -1) * self.mne.zen_h
|
||
margins["right"] += (-1 if should_show else 1) * self.mne.zen_w
|
||
self.subplots_adjust(**margins)
|
||
# handle x-axis label
|
||
self.mne.zen_xlabel.set_visible(not should_show)
|
||
# show/hide other UI elements
|
||
for elem in ("ax_hscroll", "ax_vscroll", "ax_proj", "ax_help"):
|
||
if elem == "ax_vscroll" and self.mne.butterfly:
|
||
continue
|
||
# sometimes we don't have a proj button (ax_proj)
|
||
if getattr(self.mne, elem, None) is not None:
|
||
getattr(self.mne, elem).set_visible(should_show)
|
||
self.mne.scrollbars_visible = should_show
|
||
|
||
def _update_vscroll(self):
|
||
"""Update the vertical scrollbar (channel) selection indicator."""
|
||
self.mne.vsel_patch.set_xy((0, self.mne.ch_start))
|
||
self.mne.vsel_patch.set_height(self.mne.n_channels)
|
||
self._update_yaxis_labels()
|
||
|
||
def _update_hscroll(self):
|
||
"""Update the horizontal scrollbar (time) selection indicator."""
|
||
self.mne.hsel_patch.set_xy((self.mne.t_start, 0))
|
||
self.mne.hsel_patch.set_width(self.mne.duration)
|
||
|
||
def _check_update_hscroll_clicked(self, event):
|
||
"""Handle clicks on horizontal scrollbar."""
|
||
time = event.xdata - self.mne.duration / 2
|
||
max_time = (
|
||
self.mne.n_times / self.mne.info["sfreq"]
|
||
+ self.mne.first_time
|
||
- self.mne.duration
|
||
)
|
||
time = np.clip(time, self.mne.first_time, max_time)
|
||
if self.mne.is_epochs:
|
||
ix = np.searchsorted(self.mne.boundary_times[1:], time)
|
||
time = self.mne.boundary_times[ix]
|
||
if self.mne.t_start != time:
|
||
self.mne.t_start = time
|
||
self._update_hscroll()
|
||
return True
|
||
return False
|
||
|
||
def _check_update_vscroll_clicked(self, event):
|
||
"""Update vscroll patch on click, return True if location changed."""
|
||
new_ch_start = np.clip(
|
||
int(round(event.ydata - self.mne.n_channels / 2)),
|
||
0,
|
||
len(self.mne.ch_order) - self.mne.n_channels,
|
||
)
|
||
if self.mne.ch_start != new_ch_start:
|
||
self.mne.ch_start = new_ch_start
|
||
self._update_picks()
|
||
self._update_vscroll()
|
||
return True
|
||
return False
|
||
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
# SCALEBARS & AXIS LABELS
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
|
||
def _show_scalebars(self):
|
||
"""Add channel scale bars."""
|
||
for pi, pick in enumerate(self.mne.picks):
|
||
this_name = self.mne.ch_names[pick]
|
||
this_type = self.mne.ch_types[pick]
|
||
# TODO: Simplify this someday -- we have to duplicate the challenging
|
||
# logic of _draw_traces here
|
||
offset_ixs = (
|
||
self.mne.picks
|
||
if self.mne.butterfly and self.mne.ch_selections is None
|
||
else slice(None)
|
||
)
|
||
offset = self.mne.trace_offsets[offset_ixs][pi]
|
||
if (
|
||
this_type not in self.mne.scalebars
|
||
and this_type != "stim"
|
||
and this_type in self.mne.scalings
|
||
and this_type in getattr(self.mne, "units", {})
|
||
and this_type in getattr(self.mne, "unit_scalings", {})
|
||
and this_name not in self.mne.info["bads"]
|
||
and this_name not in self.mne.whitened_ch_names
|
||
):
|
||
x = (self.mne.times[0] + self.mne.first_time,) * 2
|
||
denom = 4 if self.mne.butterfly else 2
|
||
y = tuple(np.array([-1, 1]) / denom + offset)
|
||
self._draw_one_scalebar(x, y, this_type)
|
||
if self.mne.is_epochs:
|
||
x = (
|
||
self.mne.times[0],
|
||
self.mne.times[0] + self.mne.boundary_times[1] / 2,
|
||
)
|
||
y_value = self.mne.n_channels - 0.5
|
||
y = (y_value, y_value)
|
||
self._draw_one_scalebar(x, y, "time")
|
||
|
||
def _hide_scalebars(self):
|
||
"""Remove channel scale bars."""
|
||
for bar in self.mne.scalebars.values():
|
||
bar.remove()
|
||
for text in self.mne.scalebar_texts.values():
|
||
text.remove()
|
||
self.mne.scalebars = dict()
|
||
self.mne.scalebar_texts = dict()
|
||
|
||
def _toggle_scalebars(self, event):
|
||
"""Show/hide the scalebars."""
|
||
if self.mne.scalebars_visible:
|
||
self._hide_scalebars()
|
||
else:
|
||
self._update_picks()
|
||
self._show_scalebars()
|
||
# toggle
|
||
self.mne.scalebars_visible = not self.mne.scalebars_visible
|
||
self._redraw(update_data=False)
|
||
|
||
def _draw_one_scalebar(self, x, y, ch_type):
    """Draw a scalebar.

    A text label and a thick line from ``(x[0], y[0])`` to ``(x[1], y[1])``
    are added to the main axes and registered in ``self.mne.scalebars`` /
    ``self.mne.scalebar_texts`` under ``ch_type``. ``ch_type == "time"``
    yields a duration label; any other value yields an amplitude label
    built from the scalings and units for that channel type.
    """
    from .utils import _simplify_float

    style = dict(color="#AA3377", zorder=self.mne.zorder["scalebar"])  # purple
    ax = self.mne.ax_main
    if ch_type != "time":
        scaler = 1 if self.mne.butterfly else 2
        inv_norm = (
            scaler
            * self.mne.scalings[ch_type]
            * self.mne.unit_scalings[ch_type]
            / self.mne.scale_factor
        )
        label = f"{_simplify_float(inv_norm)} {self.mne.units[ch_type]} "
        text = ax.text(
            x[1], y[1], label, va="baseline", ha="right", size="xx-small", **style
        )
    else:
        label = f"{self.mne.boundary_times[1] / 2:.2f} s"
        text = ax.text(
            x[0] + 0.015,
            y[1] - 0.05,
            label,
            va="bottom",
            ha="left",
            size="xx-small",
            **style,
        )
    (bar,) = ax.plot(x, y, lw=4, **style)[:1]
    self.mne.scalebars[ch_type] = bar
    self.mne.scalebar_texts[ch_type] = text
|
||
|
||
def _update_yaxis_labels(self):
|
||
"""Change the y-axis labels."""
|
||
if self.mne.butterfly and self.mne.fig_selection is not None:
|
||
exclude = ("Vertex", "Custom")
|
||
ticklabels = list(self.mne.ch_selections)
|
||
keep_mask = np.isin(ticklabels, exclude, invert=True)
|
||
ticklabels = [
|
||
t.replace("Left-", "L-").replace("Right-", "R-") for t in ticklabels
|
||
] # avoid having to rotate labels
|
||
ticklabels = np.array(ticklabels)[keep_mask]
|
||
elif self.mne.butterfly:
|
||
_, ixs, _ = np.intersect1d(
|
||
_DATA_CH_TYPES_ORDER_DEFAULT, self.mne.ch_types, return_indices=True
|
||
)
|
||
ixs.sort()
|
||
ticklabels = np.array(_DATA_CH_TYPES_ORDER_DEFAULT)[ixs]
|
||
else:
|
||
ticklabels = self.mne.ch_names[self.mne.picks]
|
||
texts = self.mne.ax_main.set_yticklabels(ticklabels, picker=True)
|
||
for text in texts:
|
||
sty = (
|
||
"italic" if text.get_text() in self.mne.whitened_ch_names else "normal"
|
||
)
|
||
text.set_style(sty)
|
||
|
||
def _xtick_formatter(self, x, pos=None, ax_type="main"):
|
||
"""Change the x-axis labels."""
|
||
tickdiff = np.diff(self.mne.ax_main.get_xticks())[0]
|
||
digits = np.ceil(-np.log10(tickdiff) + 1).astype(int)
|
||
# always show millisecond precision for vline text
|
||
if ax_type == "vline":
|
||
digits = 3
|
||
if self.mne.time_format == "float":
|
||
# round to integers when possible ('9.0' → '9')
|
||
if int(x) == x:
|
||
digits = None
|
||
if ax_type == "vline":
|
||
return f"{round(x, digits)} s"
|
||
return str(round(x, digits))
|
||
# format as timestamp
|
||
meas_date = self.mne.inst.info["meas_date"]
|
||
first_time = datetime.timedelta(seconds=self.mne.inst.first_time)
|
||
xtime = datetime.timedelta(seconds=x)
|
||
xdatetime = meas_date + first_time + xtime
|
||
xdtstr = xdatetime.strftime("%H:%M:%S")
|
||
if digits and ax_type != "hscroll" and int(xdatetime.microsecond):
|
||
xdtstr += f"{round(xdatetime.microsecond * 1e-6, digits)}"[1:]
|
||
return xdtstr
|
||
|
||
def _toggle_time_format(self):
|
||
if self.mne.time_format == "float":
|
||
self.mne.time_format = "clock"
|
||
x_axis_label = "Time (HH:MM:SS)"
|
||
else:
|
||
self.mne.time_format = "float"
|
||
x_axis_label = "Time (s)"
|
||
|
||
# Change x-axis label
|
||
for _ax in (self.mne.ax_main, self.mne.ax_hscroll):
|
||
_ax.set_xlabel(x_axis_label)
|
||
|
||
self._redraw(update_data=False, annotations=False)
|
||
|
||
# Update vline-text if displayed
|
||
if self.mne.vline is not None and self.mne.vline.get_visible():
|
||
self._show_vline(self.mne.vline.get_xdata())
|
||
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
# DATA TRACES
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
|
||
def _toggle_butterfly(self):
|
||
"""Enter or leave butterfly mode."""
|
||
self.mne.ax_vscroll.set_visible(self.mne.butterfly)
|
||
self.mne.butterfly = not self.mne.butterfly
|
||
self.mne.scale_factor *= 0.5 if self.mne.butterfly else 2.0
|
||
self._update_picks()
|
||
self._update_trace_offsets()
|
||
self._redraw(annotations=True)
|
||
if self.mne.fig_selection is not None:
|
||
self.mne.fig_selection._style_radio_buttons_butterfly()
|
||
|
||
def _update_trace_offsets(self):
|
||
"""Compute viewport height and adjust offsets."""
|
||
# simultaneous selection and butterfly modes
|
||
if self.mne.butterfly and self.mne.ch_selections is not None:
|
||
self._update_picks()
|
||
selections_dict = self._make_butterfly_selections_dict()
|
||
n_offsets = len(selections_dict)
|
||
sel_order = list(selections_dict)
|
||
offsets = np.array([])
|
||
for pick in self.mne.picks:
|
||
for sel in sel_order:
|
||
if pick in selections_dict[sel]:
|
||
offsets = np.append(offsets, sel_order.index(sel))
|
||
# butterfly only
|
||
elif self.mne.butterfly:
|
||
unique_ch_types = set(self.mne.ch_types)
|
||
n_offsets = len(unique_ch_types)
|
||
ch_type_order = [
|
||
_type
|
||
for _type in _DATA_CH_TYPES_ORDER_DEFAULT
|
||
if _type in unique_ch_types
|
||
]
|
||
offsets = np.array(
|
||
[ch_type_order.index(ch_type) for ch_type in self.mne.ch_types]
|
||
)
|
||
# normal mode
|
||
else:
|
||
n_offsets = self.mne.n_channels
|
||
offsets = np.arange(n_offsets, dtype=float)
|
||
# update ylim, ticks, vertline, and scrollbar patch
|
||
ylim = (n_offsets - 0.5, -0.5) # inverted y axis → new chs at bottom
|
||
self.mne.ax_main.set_ylim(ylim)
|
||
self.mne.ax_main.set_yticks(np.unique(offsets))
|
||
self.mne.vsel_patch.set_height(self.mne.n_channels)
|
||
# store new offsets, update axis labels
|
||
self.mne.trace_offsets = offsets
|
||
self._update_yaxis_labels()
|
||
|
||
def _draw_traces(self):
    """Draw (or redraw) the channel data.

    Updates the existing trace artists in ``self.mne.traces`` in place
    (creating or removing artists so that there is exactly one per visible
    pick), colors traces and y-tick labels, applies decimation and
    clipping, and for epochs draws additional per-color traces stored in
    ``self.mne.epoch_traces``. Finally the x-limits are set and scalebars
    and event lines are redrawn when applicable.
    """
    from matplotlib.colors import to_rgba_array
    from matplotlib.patches import Rectangle

    # clear scalebars
    if self.mne.scalebars_visible:
        self._hide_scalebars()
    # get info about currently visible channels
    picks = self.mne.picks
    ch_names = self.mne.ch_names[picks]
    ch_types = self.mne.ch_types[picks]
    # in plain butterfly mode trace_offsets is indexed by pick; otherwise
    # it is used as-is
    offset_ixs = (
        picks
        if self.mne.butterfly and self.mne.ch_selections is None
        else slice(None)
    )
    offsets = self.mne.trace_offsets[offset_ixs]
    bad_bool = np.isin(ch_names, self.mne.info["bads"])
    # colors
    good_ch_colors = [self.mne.ch_color_dict[_type] for _type in ch_types]
    ch_colors = to_rgba_array(
        [
            self.mne.ch_color_bad if _bad else _color
            for _bad, _color in zip(bad_bool, good_ch_colors)
        ]
    )
    self.mne.ch_colors = np.array(good_ch_colors)  # use for unmarking bads
    labels = self.mne.ax_main.yaxis.get_ticklabels()
    if self.mne.butterfly:
        # butterfly labels are not per-channel → use the foreground color
        for label in labels:
            label.set_color(self.mne.fgcolor)
    else:
        for label, color in zip(labels, ch_colors):
            label.set_color(color)
    # decim
    # only picks that are in picks_data get self.mne.decim; others stay at 1
    decim = np.ones_like(picks)
    data_picks_mask = np.isin(picks, self.mne.picks_data)
    decim[data_picks_mask] = self.mne.decim
    # decim can vary by channel type, so compute different `times` vectors
    decim_times = {
        decim_value: self.mne.times[::decim_value] + self.mne.first_time
        for decim_value in set(decim)
    }
    # add more traces if needed
    n_picks = len(picks)
    if n_picks > len(self.mne.traces):
        n_new_chs = n_picks - len(self.mne.traces)
        # NaN placeholders: the real data is assigned in the loop below
        new_traces = self.mne.ax_main.plot(
            np.full((1, n_new_chs), np.nan), **self.mne.trace_kwargs
        )
        self.mne.traces.extend(new_traces)
    # remove extra traces if needed
    extra_traces = self.mne.traces[n_picks:]
    for trace in extra_traces:
        trace.remove()
    self.mne.traces = self.mne.traces[:n_picks]

    # check for bad epochs
    time_range = (self.mne.times + self.mne.first_time)[[0, -1]]
    if self.mne.instance_type == "epochs":
        # indices of the epochs that fall inside the visible time range
        epoch_ix = np.searchsorted(self.mne.boundary_times, time_range)
        epoch_ix = np.arange(epoch_ix[0], epoch_ix[1])
        epoch_nums = self.mne.inst.selection[epoch_ix[0] : epoch_ix[-1] + 1]
        (visible_bad_epoch_ix,) = np.isin(epoch_nums, self.mne.bad_epochs).nonzero()
        # discard the extra per-color traces from the previous draw
        while len(self.mne.epoch_traces):
            self.mne.epoch_traces.pop(-1).remove()
        # handle custom epoch colors (for autoreject integration)
        if self.mne.epoch_colors is None:
            # shape: n_traces × RGBA → n_traces × n_epochs × RGBA
            custom_colors = np.tile(
                ch_colors[:, None, :], (1, self.mne.n_epochs, 1)
            )
        else:
            custom_colors = np.empty((len(self.mne.picks), self.mne.n_epochs, 4))
            for ii, _epoch_ix in enumerate(epoch_ix):
                this_colors = self.mne.epoch_colors[_epoch_ix]
                custom_colors[:, ii] = to_rgba_array(
                    [this_colors[_ch] for _ch in picks]
                )
        # override custom color on bad epochs
        for _ix in visible_bad_epoch_ix:
            # per channel: epoch_color_bad for good channels, ch_color_bad
            # for bad channels (selected via bad_bool as a 0/1 index)
            _cols = np.array(
                [self.mne.epoch_color_bad, self.mne.ch_color_bad], dtype=object
            )[bad_bool.astype(int)]
            custom_colors[:, _ix] = to_rgba_array(_cols)

    # update traces
    ylim = self.mne.ax_main.get_ylim()
    for ii, line in enumerate(self.mne.traces):
        this_name = ch_names[ii]
        this_type = ch_types[ii]
        this_offset = offsets[ii]
        this_times = decim_times[decim[ii]]
        # data is subtracted from the offset before plotting
        this_data = this_offset - self.mne.data[ii] * self.mne.scale_factor
        this_data = this_data[..., :: decim[ii]]
        # clip
        if self.mne.clipping == "clamp":
            # NOTE(review): the clamp is applied after the offset has been
            # added, so the bounds are absolute axis coordinates rather than
            # relative to this_offset — confirm this is intended
            this_data = np.clip(this_data, -0.5, 0.5)
        elif self.mne.clipping is not None:
            # numeric clipping: restrict drawing to a rectangle around the
            # trace offset, limited to the current y-limits
            clip = self.mne.clipping * (0.2 if self.mne.butterfly else 1)
            bottom = max(this_offset - clip, ylim[1])
            height = min(2 * clip, ylim[0] - bottom)
            rect = Rectangle(
                xy=np.array([time_range[0], bottom]),
                width=time_range[1] - time_range[0],
                height=height,
                transform=self.mne.ax_main.transData,
            )
            line.set_clip_path(rect)
        # prep z order
        is_bad_ch = this_name in self.mne.info["bads"]
        this_z = self.mne.zorder["bads" if is_bad_ch else "data"]
        if self.mne.butterfly and not is_bad_ch:
            # in butterfly mode a per-type zorder is used when available
            this_z = self.mne.zorder.get(this_type, this_z)
        # plot each trace multiple times to get the desired epoch coloring.
        # use masked arrays to plot discontinuous epochs that have the same
        # color in a single plot() call.
        if self.mne.instance_type == "epochs":
            this_colors = custom_colors[ii]
            for cix, color in enumerate(np.unique(this_colors, axis=0)):
                bool_ixs = (this_colors == color).all(axis=1)
                mask = np.zeros_like(this_times, dtype=bool)
                _starts = self.mne.boundary_times[epoch_ix][bool_ixs]
                _stops = self.mne.boundary_times[epoch_ix + 1][bool_ixs]
                for _start, _stop in zip(_starts, _stops):
                    _mask = np.logical_and(_start < this_times, this_times <= _stop)
                    mask = mask | _mask
                # mask out every time point that is not in this color's epochs
                _times = np.ma.masked_array(this_times, mask=~mask)
                # always use the existing traces first
                if cix == 0:
                    line.set_xdata(_times)
                    line.set_ydata(this_data)
                    line.set_color(color)
                    line.set_zorder(this_z)
                else:  # make new traces as needed
                    _trace = self.mne.ax_main.plot(
                        _times,
                        this_data,
                        color=color,
                        zorder=this_z,
                        **self.mne.trace_kwargs,
                    )
                    self.mne.epoch_traces.extend(_trace)
        else:
            line.set_xdata(this_times)
            line.set_ydata(this_data)
            line.set_color(ch_colors[ii])
            line.set_zorder(this_z)
    # update xlim
    self.mne.ax_main.set_xlim(*time_range)
    # draw scalebars maybe
    if self.mne.scalebars_visible:
        self._show_scalebars()
    # redraw event lines
    if self.mne.event_times is not None:
        self._draw_event_lines()
|
||
|
||
def _redraw(self, update_data=True, annotations=False):
    """Redraw (convenience method for frequently grouped actions).

    Delegates to the parent implementation, re-positions the epoch vlines
    when they are visible, and schedules a canvas redraw.
    """
    super()._redraw(update_data, annotations)
    refresh_vlines = self.mne.vline_visible and self.mne.is_epochs
    if refresh_vlines:
        # prevent flickering (return value intentionally ignored)
        self._recompute_epochs_vlines(None)
    self.canvas.draw_idle()
|
||
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
# EVENT LINES AND MARKER LINES
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
|
||
def _draw_event_lines(self):
    """Draw the event lines and their labels.

    Does nothing when ``self.mne.event_nums`` is None. Otherwise a
    ``LineCollection`` with one vertical line per event inside the current
    time window is added to the main axes and stored in
    ``self.mne.event_lines``; the previous event labels are removed and
    new ones are created above the axes.
    """
    from matplotlib.collections import LineCollection
    from matplotlib.colors import to_rgba_array

    if self.mne.event_nums is None:
        return
    ax = self.mne.ax_main
    # restrict to events inside the current time window
    in_view = np.logical_and(
        self.mne.event_times >= self.mne.times[0],
        self.mne.event_times <= self.mne.times[-1],
    )
    this_event_times = self.mne.event_times[in_view]
    this_event_nums = self.mne.event_nums[in_view]
    n_visible_events = len(this_event_times)
    colors = to_rgba_array(
        [self.mne.event_color_dict[num] for num in this_event_nums]
    )
    # create event lines: each segment spans the full current y-range
    ylim = ax.get_ylim()
    seg_x = np.repeat(this_event_times, 2)
    seg_y = np.tile(ylim, n_visible_events)
    segments = np.vstack([seg_x, seg_y]).T.reshape(n_visible_events, 2, 2)
    event_lines = LineCollection(
        segments, linewidths=0.5, colors=colors, zorder=self.mne.zorder["events"]
    )
    ax.add_collection(event_lines)
    self.mne.event_lines = event_lines
    # create event labels (after discarding the old ones)
    while len(self.mne.event_texts):
        self.mne.event_texts.pop().remove()
    for event_time, event_num, _ in zip(this_event_times, this_event_nums, colors):
        # fall back to the raw event number when no name is known
        label = self.mne.event_id_rev.get(event_num, event_num)
        annotation = ax.annotate(
            label,
            (event_time, ylim[1]),
            ha="center",
            va="baseline",
            color=self.mne.fgcolor,
            xytext=(0, 2),
            textcoords="offset points",
            fontsize=8,
        )
        self.mne.event_texts.append(annotation)
|
||
|
||
def _recompute_epochs_vlines(self, xdata):
|
||
"""Recompute vline x-coords for epochs plots (after scrolling, etc)."""
|
||
# special case: changed view duration w/ "home" or "end" key
|
||
# (no click event, hence no xdata)
|
||
if xdata is None:
|
||
xdata = np.array(self.mne.vline.get_segments())[0, 0, 0]
|
||
# compute the (continuous) times for the lines on each epoch
|
||
epoch_dur = np.diff(self.mne.boundary_times[:2])[0]
|
||
rel_time = xdata % epoch_dur
|
||
abs_time = self.mne.times[0]
|
||
xs = np.arange(self.mne.n_epochs) * epoch_dur + abs_time + rel_time
|
||
segs = np.array(self.mne.vline.get_segments())
|
||
# recreate segs from scratch in case view duration changed
|
||
# (i.e., handle case when n_segments != n_epochs)
|
||
segs = np.tile([[0.0], [1.0]], (len(xs), 1, 2)) # y values
|
||
segs[..., 0] = np.tile(xs[:, None], 2) # x values
|
||
self.mne.vline.set_segments(segs)
|
||
return rel_time
|
||
|
||
def _show_vline(self, xdata):
|
||
"""Show the vertical line(s)."""
|
||
if self.mne.is_epochs:
|
||
# convert xdata to be epoch-relative (for the text)
|
||
rel_time = self._recompute_epochs_vlines(xdata)
|
||
xdata = rel_time + self.mne.inst.times[0]
|
||
else:
|
||
self.mne.vline.set_xdata([xdata])
|
||
self.mne.vline_hscroll.set_xdata([xdata])
|
||
text = self._xtick_formatter(xdata, ax_type="vline")[:12]
|
||
self.mne.vline_text.set_text(text)
|
||
self._toggle_vline(True)
|
||
|
||
def _toggle_vline(self, visible):
|
||
"""Show or hide the vertical line(s)."""
|
||
for artist in (self.mne.vline, self.mne.vline_hscroll, self.mne.vline_text):
|
||
if artist is not None:
|
||
artist.set_visible(visible)
|
||
self.draw_artist(artist)
|
||
self.mne.vline_visible = visible
|
||
self.canvas.draw_idle()
|
||
|
||
# workaround: plt.close() doesn't spawn close_event on Agg backend, this method
|
||
# can be removed once the _close_event in fixes.py is removed
|
||
def _close_event(self, fig=None):
    """Force calling of the MPL figure close event.

    Acts on ``fig`` when given (and truthy), otherwise on this figure.
    """
    _close_event(fig or self)
|
||
|
||
def _fake_keypress(self, key, fig=None):
    """Fake a press of ``key`` on ``fig`` (this figure when not given)."""
    _fake_keypress(fig or self, key)
|
||
|
||
def _fake_click(
    self,
    point,
    add_points=None,
    fig=None,
    ax=None,
    xform="ax",
    button=1,
    kind="press",
):
    """Fake a click at a relative point within axes.

    ``fig`` and ``ax`` default to this figure and its main axes. With
    ``kind="drag"`` and ``add_points`` given, a press at ``point`` is
    followed by a motion event at each additional point and a release at
    the last one; otherwise a single event of the requested kind is sent.
    """
    # bind the arguments shared by every synthetic event
    send = partial(
        _fake_click,
        fig=fig or self,
        ax=ax or self.mne.ax_main,
        xform=xform,
        button=button,
    )
    if kind != "drag" or add_points is None:
        send(point=point, kind=kind)
        return
    send(point=point, kind="press")
    for waypoint in add_points:
        send(point=waypoint, kind="motion")
    send(point=add_points[-1], kind="release")
|
||
|
||
def _fake_scroll(self, x, y, step, fig=None):
    """Fake a scroll of ``step`` at ``(x, y)`` on ``fig`` (default: self)."""
    _fake_scroll(fig or self, x, y, step)
|
||
|
||
def _click_ch_name(self, ch_index, button):
    """Forward to the module-level ``_click_ch_name`` helper for this figure."""
    _click_ch_name(self, ch_index=ch_index, button=button)
|
||
|
||
def _resize_by_factor(self, factor=None):
|
||
size = self.canvas.manager.canvas.get_width_height()
|
||
if isinstance(factor, tuple):
|
||
size = int(size[0] * factor[0], size[1] * factor[1])
|
||
else:
|
||
size = [int(x * factor) for x in size]
|
||
self.canvas.manager.resize(*size)
|
||
|
||
def _get_ticklabels(self, orientation):
|
||
if orientation == "x":
|
||
labels = self.mne.ax_main.get_xticklabels(minor=self.mne.is_epochs)
|
||
elif orientation == "y":
|
||
labels = self.mne.ax_main.get_yticklabels()
|
||
label_texts = [lb.get_text() for lb in labels]
|
||
|
||
return label_texts
|
||
|
||
def _get_scale_bar_texts(self):
|
||
texts = tuple(t.get_text().strip() for t in self.mne.ax_main.texts)
|
||
# First text is empty because of vline
|
||
texts = texts[1:]
|
||
|
||
return texts
|
||
|
||
|
||
class MNELineFigure(MNEFigure):
    """Interactive figure for non-scrolling line plots."""

    def __init__(self, inst, n_axes, figsize, *, layout="constrained", **kwargs):
        """Create the figure and stack ``n_axes`` subplots in one column."""
        init_kwargs = dict(figsize=figsize, inst=inst, layout=layout, sharex=True)
        super().__init__(**init_kwargs, **kwargs)
        # subplot positions are 1-based
        for position in range(1, n_axes + 1):
            self.add_subplot(n_axes, 1, position)
|
||
|
||
|
||
def _close_all():
    """Close all figures (only used in our tests)."""
    everything = "all"
    plt.close(everything)
|
||
|
||
|
||
def _get_n_figs():
    """Return the number of figures currently registered with pyplot."""
    open_fignums = plt.get_fignums()
    return len(open_fignums)
|
||
|
||
|
||
def _figure(toolbar=True, FigureClass=MNEFigure, **kwargs):
    """Instantiate a new figure.

    ``window_title`` is popped from ``kwargs`` and applied after creation;
    ``layout`` defaults to ``"constrained"``. When ``toolbar`` is false the
    figure is created with the ``toolbar`` rcParam set to ``"none"``. The
    remaining ``kwargs`` are forwarded to ``plt.figure``.
    """
    from matplotlib import rc_context

    title = kwargs.pop("window_title", None)  # extract title before init
    kwargs.setdefault("layout", "constrained")
    rc_overrides = dict(toolbar="none") if not toolbar else dict()
    with rc_context(rc=rc_overrides):
        fig = plt.figure(FigureClass=FigureClass, **kwargs)
    # BACKEND defined globally at the top of this file
    fig.mne.backend = BACKEND
    if title is not None:
        _set_window_title(fig, title)
    # TODO: for some reason for topomaps->_prepare_trellis the layout=constrained does
    # not work the first time (maybe toolbar=False?)
    if kwargs.get("layout") == "constrained":
        fig.set_layout_engine("constrained")

    # add event callbacks
    fig._add_default_callbacks()
    return fig
|
||
|
||
|
||
def _line_figure(inst, axes=None, picks=None, **kwargs):
    """Instantiate a new line figure.

    Returns ``(fig, axes)``. When ``axes`` is given it is validated against
    the expected number of axes and its figure is reused; otherwise a new
    ``MNELineFigure`` with one subplot per channel type is created.
    """
    from matplotlib.axes import Axes

    # if picks is None, only show data channels
    if picks is None:
        allowed_ch_types = _DATA_CH_TYPES_SPLIT
    else:
        allowed_ch_types = _VALID_CHANNEL_TYPES
    # figure out expected number of axes (one per allowed type present)
    ch_types = np.array(inst.get_channel_types())
    if picks is not None:
        ch_types = ch_types[picks]
    n_axes = len(np.intersect1d(ch_types, allowed_ch_types))
    if axes is None:
        figsize = kwargs.pop("figsize", (10, 2.5 * n_axes + 1))
        fig = _figure(
            inst=inst,
            toolbar=True,
            FigureClass=MNELineFigure,
            figsize=figsize,
            n_axes=n_axes,
            **kwargs,
        )
        fig.mne.fig_size_px = fig._get_size_px()  # can't do in __init__
        return fig, fig.axes
    # handle user-provided axes
    if isinstance(axes, Axes):
        axes = [axes]
    _validate_if_list_of_axes(axes, n_axes)
    return axes[0].get_figure(), axes
|
||
|
||
|
||
def _split_picks_by_type(inst, picks, units, scalings, titles):
    """Separate picks, units, etc, for plotting on separate subplots.

    Parameters
    ----------
    inst : object
        Instance whose ``info`` is passed to ``pick_types``.
    picks : array-like of int | None
        Channel indices to keep. If None, all channels of each data channel
        type are used.
    units, scalings, titles : dict
        Mappings from channel type to the respective value.

    Returns
    -------
    picks_list, units_list, scalings_list, titles_list : list
        Parallel lists with one entry per channel type that has at least
        one channel.

    Raises
    ------
    RuntimeError
        If no channel type has any channel.
    """
    picks_list = list()
    units_list = list()
    scalings_list = list()
    titles_list = list()
    # if picks is None, only show data channels
    allowed_ch_types = _DATA_CH_TYPES_SPLIT if picks is None else _VALID_CHANNEL_TYPES
    for ch_type in allowed_ch_types:
        pick_kwargs = dict(meg=False, ref_meg=False, exclude=[])
        # some channel types are selected through a shared keyword
        if ch_type in ("mag", "grad"):
            pick_kwargs["meg"] = ch_type
        elif ch_type in _FNIRS_CH_TYPES_SPLIT:
            pick_kwargs["fnirs"] = ch_type
        elif ch_type in _EYETRACK_CH_TYPES_SPLIT:
            pick_kwargs["eyetrack"] = ch_type
        else:
            pick_kwargs[ch_type] = True
        these_picks = pick_types(inst.info, **pick_kwargs)
        if picks is not None:
            # intersecting with None cannot work, so only restrict to the
            # requested picks when some were actually given
            these_picks = np.intersect1d(these_picks, picks)
        if len(these_picks) > 0:
            picks_list.append(these_picks)
            units_list.append(units[ch_type])
            scalings_list.append(scalings[ch_type])
            titles_list.append(titles[ch_type])
    if len(picks_list) == 0:
        raise RuntimeError("No data channels found")
    return picks_list, units_list, scalings_list, titles_list
|
||
|
||
|
||
def _calc_new_margins(fig, old_width, old_height, new_width, new_height):
|
||
"""Compute new figure-relative values to maintain fixed-size margins."""
|
||
new_margins = dict()
|
||
for side in ("left", "right", "bottom", "top"):
|
||
ratio = (
|
||
(old_width / new_width)
|
||
if side in ("left", "right")
|
||
else (old_height / new_height)
|
||
)
|
||
rel_dim = getattr(fig.subplotpars, side)
|
||
if side in ("right", "top"):
|
||
new_margins[side] = 1 - ratio * (1 - rel_dim)
|
||
else:
|
||
new_margins[side] = ratio * rel_dim
|
||
# gh-8304: don't allow resizing too small
|
||
if (
|
||
new_margins["bottom"] < new_margins["top"]
|
||
and new_margins["left"] < new_margins["right"]
|
||
):
|
||
return new_margins
|
||
|
||
|
||
@contextmanager
|
||
def _patched_canvas(fig):
|
||
old_canvas = fig.canvas
|
||
if fig.canvas is None: # XXX old MPL (at least 3.0.3) does this for Agg
|
||
fig.canvas = Bunch(mpl_connect=lambda event, callback: None)
|
||
try:
|
||
yield
|
||
finally:
|
||
fig.canvas = old_canvas
|
||
|
||
|
||
def _init_browser(**kwargs):
    """Instantiate a new MNE browse-style figure.

    All keyword arguments are forwarded to ``_figure`` (with no toolbar,
    ``MNEBrowseFigure`` as the class, and ``layout=None``). The returned
    figure has been drawn once, sized, and populated with data.
    """
    from mne.io import BaseRaw

    fig = _figure(toolbar=False, FigureClass=MNEBrowseFigure, layout=None, **kwargs)
    params = fig.mne

    # splash is ignored (maybe we could do it for mpl if we get_backend() and
    # check if it's Qt... but seems overkill)

    # initialize zen mode
    # (can't do in __init__ due to get_position() calls)
    fig.canvas.draw()
    fig._update_zen_mode_offsets()
    fig._resize(None)  # needed for MPL

    # if scrollbars are supposed to start hidden,
    # set to True and then toggle
    if not params.scrollbars_visible:
        params.scrollbars_visible = True
        fig._toggle_scrollbars()

    # Initialize parts of the plot
    if params.instance_type != "ica":
        # make channel selection dialog,
        # if requested (doesn't work well in init)
        if params.group_by in ("selection", "position"):
            fig._create_selection_fig()

        # start with projectors dialog open, if requested
        if getattr(params, "show_options", False):
            fig._toggle_proj_fig()

    # update data, and plot
    fig._update_trace_offsets()
    fig._redraw(update_data=True, annotations=False)

    # annotations only exist for Raw instances
    if isinstance(params.inst, BaseRaw):
        fig._setup_annotation_colors()
        fig._draw_annotations()

    return fig
|
||
|
||
|
||
def _get_check_kwargs(labels=None):
    """Build styling keyword arguments for check-button widgets.

    Returns an empty dict when ``_OLD_BUTTONS`` is set. Otherwise the dict
    holds ``check_props`` and ``frame_props`` and, when ``labels`` is
    given, per-label colors (plus ``label_props``) that grey out labels
    ending in ``"(already applied)"``.
    """
    if _OLD_BUTTONS:
        return dict()
    check_kwargs = dict(
        check_props=dict(s=144, clip_on=False),
        frame_props=dict(s=144, clip_on=False),
    )
    if labels is None:
        return check_kwargs
    # grey for entries that are already applied, black otherwise
    applied = [label.endswith("(already applied)") for label in labels]
    textcolor = ["0.5" if is_applied else "k" for is_applied in applied]
    checkcolor = ["0.7" if is_applied else "k" for is_applied in applied]
    check_kwargs["check_props"].update(facecolor=checkcolor, linewidth=1)
    check_kwargs["frame_props"].update(edgecolor=checkcolor, linewidth=1)
    check_kwargs["label_props"] = dict(color=textcolor)
    return check_kwargs
|