253 lines
9.5 KiB
Python
253 lines
9.5 KiB
Python
"""Functions for plotting projectors."""
|
|
|
|
# Authors: The MNE-Python contributors.
|
|
# License: BSD-3-Clause
|
|
# Copyright the MNE-Python contributors.
|
|
|
|
from copy import deepcopy
|
|
|
|
import numpy as np
|
|
|
|
from .._fiff.pick import _picks_to_idx
|
|
from ..defaults import DEFAULTS
|
|
from ..utils import _pl, _validate_type, verbose, warn
|
|
from .evoked import _plot_evoked
|
|
from .topomap import _plot_projs_topomap
|
|
from .utils import _check_type_projs, plt_show
|
|
|
|
|
|
@verbose
|
|
def plot_projs_joint(
|
|
projs, evoked, picks_trace=None, *, topomap_kwargs=None, show=True, verbose=None
|
|
):
|
|
"""Plot projectors and evoked jointly.
|
|
|
|
Parameters
|
|
----------
|
|
projs : list of Projection
|
|
The projectors to plot.
|
|
evoked : instance of Evoked
|
|
The data to plot. Typically this is the evoked instance created from
|
|
averaging the epochs used to create the projection.
|
|
%(picks_plot_projs_joint_trace)s
|
|
topomap_kwargs : dict | None
|
|
Keyword arguments to pass to :func:`mne.viz.plot_projs_topomap`.
|
|
%(show)s
|
|
%(verbose)s
|
|
|
|
Returns
|
|
-------
|
|
fig : instance of matplotlib Figure
|
|
The figure.
|
|
|
|
Notes
|
|
-----
|
|
This function creates a figure with three columns:
|
|
|
|
1. The left shows the evoked data traces before (black) and after (green)
|
|
projection.
|
|
2. The center shows the topomaps associated with each of the projectors.
|
|
3. The right again shows the data traces (black), but this time with:
|
|
|
|
1. The data projected onto each projector with a single normalization
|
|
factor (solid lines). This is useful for seeing the relative power
|
|
in each projection vector.
|
|
2. The data projected onto each projector with individual normalization
|
|
factors (dashed lines). This is useful for visualizing each time
|
|
course regardless of its power.
|
|
3. Additional data traces from ``picks_trace`` (solid yellow lines).
|
|
This is useful for visualizing the "ground truth" of the time
|
|
course, e.g. the measured EOG or ECG channel time courses.
|
|
|
|
.. versionadded:: 1.1
|
|
"""
|
|
import matplotlib.pyplot as plt
|
|
|
|
from ..evoked import Evoked
|
|
|
|
_validate_type(evoked, Evoked, "evoked")
|
|
_validate_type(topomap_kwargs, (None, dict), "topomap_kwargs")
|
|
projs = _check_type_projs(projs)
|
|
topomap_kwargs = dict() if topomap_kwargs is None else topomap_kwargs
|
|
if picks_trace is not None:
|
|
picks_trace = _picks_to_idx(evoked.info, picks_trace, allow_empty=False)
|
|
info = evoked.info
|
|
ch_types = evoked.get_channel_types(unique=True, only_data_chs=True)
|
|
proj_by_type = dict() # will be set up like an enumerate key->[pi, proj]
|
|
ch_names_by_type = dict()
|
|
used = np.zeros(len(projs), int)
|
|
for ch_type in ch_types:
|
|
these_picks = _picks_to_idx(info, ch_type, allow_empty=True)
|
|
these_chs = [evoked.ch_names[pick] for pick in these_picks]
|
|
ch_names_by_type[ch_type] = these_chs
|
|
for pi, proj in enumerate(projs):
|
|
if not set(these_chs).intersection(proj["data"]["col_names"]):
|
|
continue
|
|
if ch_type not in proj_by_type:
|
|
proj_by_type[ch_type] = list()
|
|
proj_by_type[ch_type].append([pi, deepcopy(proj)])
|
|
used[pi] += 1
|
|
missing = (~used.astype(bool)).sum()
|
|
if missing:
|
|
warn(
|
|
f"{missing} projector{_pl(missing)} had no channel names "
|
|
"present in epochs"
|
|
)
|
|
del projs
|
|
ch_types = list(proj_by_type) # reduce to number we actually need
|
|
# room for legend
|
|
max_proj_per_type = max(len(x) for x in proj_by_type.values())
|
|
cs_trace = 3
|
|
cs_topo = 2
|
|
n_col = max_proj_per_type * cs_topo + 2 * cs_trace
|
|
n_row = len(ch_types)
|
|
shape = (n_row, n_col)
|
|
fig = plt.figure(
|
|
figsize=(n_col * 1.1 + 0.5, n_row * 1.8 + 0.5), layout="constrained"
|
|
)
|
|
ri = 0
|
|
# pick some sufficiently distinct colors (6 per proj type, e.g., ECG,
|
|
# should be enough hopefully!)
|
|
# https://personal.sron.nl/~pault/data/colourschemes.pdf
|
|
# "Vibrant" color scheme
|
|
proj_colors = [
|
|
"#CC3311", # red
|
|
"#009988", # teal
|
|
"#0077BB", # blue
|
|
"#EE3377", # magenta
|
|
"#EE7733", # orange
|
|
"#33BBEE", # cyan
|
|
]
|
|
trace_color = "#CCBB44" # yellow
|
|
after_color, after_name = "#228833", "green"
|
|
type_titles = DEFAULTS["titles"]
|
|
last_ax = [None] * 2
|
|
first_ax = dict()
|
|
pe_kwargs = dict(show=False, draw=False)
|
|
for ch_type, these_projs in proj_by_type.items():
|
|
these_idxs, these_projs = zip(*these_projs)
|
|
ch_names = ch_names_by_type[ch_type]
|
|
idx = np.where(
|
|
[np.isin(ch_names, proj["data"]["col_names"]).all() for proj in these_projs]
|
|
)[0]
|
|
used[idx] += 1
|
|
count = len(these_projs)
|
|
for proj in these_projs:
|
|
sub_idx = [proj["data"]["col_names"].index(name) for name in ch_names]
|
|
proj["data"]["data"] = proj["data"]["data"][:, sub_idx]
|
|
proj["data"]["col_names"] = ch_names
|
|
ba_ax = plt.subplot2grid(shape, (ri, 0), colspan=cs_trace, fig=fig)
|
|
topo_axes = [
|
|
plt.subplot2grid(
|
|
shape, (ri, ci * cs_topo + cs_trace), colspan=cs_topo, fig=fig
|
|
)
|
|
for ci in range(count)
|
|
]
|
|
tr_ax = plt.subplot2grid(
|
|
shape, (ri, n_col - cs_trace), colspan=cs_trace, fig=fig
|
|
)
|
|
# topomaps
|
|
_plot_projs_topomap(these_projs, info=info, axes=topo_axes, **topomap_kwargs)
|
|
for idx, proj, ax_ in zip(these_idxs, these_projs, topo_axes):
|
|
ax_.set_title("") # could use proj['desc'] but it's long
|
|
ax_.set_xlabel(f"projs[{idx}]", fontsize="small")
|
|
unit = DEFAULTS["units"][ch_type]
|
|
# traces
|
|
this_evoked = evoked.copy().pick(ch_names)
|
|
p = np.concatenate([p["data"]["data"] for p in these_projs])
|
|
assert p.shape == (len(these_projs), len(this_evoked.data))
|
|
traces = np.dot(p, this_evoked.data)
|
|
traces *= np.sign(np.mean(np.dot(this_evoked.data, traces.T), 0))[:, np.newaxis]
|
|
if picks_trace is not None:
|
|
ch_traces = evoked.data[picks_trace]
|
|
ch_traces -= np.mean(ch_traces, axis=1, keepdims=True)
|
|
ch_traces /= np.abs(ch_traces).max()
|
|
_plot_evoked(
|
|
this_evoked, picks="all", axes=[tr_ax], **pe_kwargs, spatial_colors=False
|
|
)
|
|
for line in tr_ax.lines:
|
|
line.set(lw=0.5, zorder=3)
|
|
for t in list(tr_ax.texts):
|
|
t.remove()
|
|
scale = 0.8 * np.abs(tr_ax.get_ylim()).max()
|
|
hs, labels = list(), list()
|
|
traces /= np.abs(traces).max() # uniformly scaled
|
|
for ti, trace in enumerate(traces):
|
|
hs.append(
|
|
tr_ax.plot(
|
|
this_evoked.times,
|
|
trace * scale,
|
|
color=proj_colors[ti % len(proj_colors)],
|
|
zorder=5,
|
|
)[0]
|
|
)
|
|
labels.append(f"projs[{these_idxs[ti]}]")
|
|
traces /= np.abs(traces).max(1, keepdims=True) # independently
|
|
for ti, trace in enumerate(traces):
|
|
tr_ax.plot(
|
|
this_evoked.times,
|
|
trace * scale,
|
|
color=proj_colors[ti % len(proj_colors)],
|
|
zorder=3.5,
|
|
ls="--",
|
|
lw=1.0,
|
|
alpha=0.75,
|
|
)
|
|
if picks_trace is not None:
|
|
trace_ch = [evoked.ch_names[pick] for pick in picks_trace]
|
|
if len(picks_trace) == 1:
|
|
trace_ch = trace_ch[0]
|
|
hs.append(
|
|
tr_ax.plot(
|
|
this_evoked.times,
|
|
ch_traces.T * scale,
|
|
color=trace_color,
|
|
lw=3,
|
|
zorder=4,
|
|
alpha=0.75,
|
|
)[0]
|
|
)
|
|
labels.append(str(trace_ch))
|
|
tr_ax.set(title="", xlabel="", ylabel="")
|
|
# This will steal space from the subplots in a constrained layout
|
|
# https://matplotlib.org/3.5.0/tutorials/intermediate/constrainedlayout_guide.html#legends # noqa: E501
|
|
tr_ax.legend(
|
|
hs,
|
|
labels,
|
|
loc="center left",
|
|
borderaxespad=0.05,
|
|
bbox_to_anchor=[1.05, 0.5],
|
|
)
|
|
last_ax[1] = tr_ax
|
|
key = "Projected time course"
|
|
if key not in first_ax:
|
|
first_ax[key] = tr_ax
|
|
# Before and after traces
|
|
_plot_evoked(this_evoked, picks="all", axes=[ba_ax], **pe_kwargs)
|
|
for line in ba_ax.lines:
|
|
line.set(lw=0.5, zorder=3)
|
|
loff = len(ba_ax.lines)
|
|
this_proj_evoked = this_evoked.copy().add_proj(these_projs)
|
|
# with meg='combined' any existing mag projectors (those already part
|
|
# of evoked before we add_proj above) will have greatly
|
|
# reduced power, so we ignore the warning about this issue
|
|
this_proj_evoked.apply_proj(verbose="error")
|
|
_plot_evoked(this_proj_evoked, picks="all", axes=[ba_ax], **pe_kwargs)
|
|
for line in ba_ax.lines[loff:]:
|
|
line.set(lw=0.5, zorder=4, color=after_color)
|
|
for t in list(ba_ax.texts):
|
|
t.remove()
|
|
ba_ax.set(title="", xlabel="")
|
|
ba_ax.set(ylabel=f"{type_titles[ch_type]}\n{unit}")
|
|
last_ax[0] = ba_ax
|
|
key = f"Before (black) and after ({after_name})"
|
|
if key not in first_ax:
|
|
first_ax[key] = ba_ax
|
|
ri += 1
|
|
for ax in last_ax:
|
|
ax.set(xlabel="Time (s)")
|
|
for title, ax in first_ax.items():
|
|
ax.set_title(title, fontsize="medium")
|
|
plt_show(show)
|
|
return fig
|