"""Functions to plot M/EEG data e.g. topographies.""" # Authors: The MNE-Python contributors. # License: BSD-3-Clause # Copyright the MNE-Python contributors. import copy import itertools import warnings from functools import partial from numbers import Integral import numpy as np from scipy.interpolate import ( CloughTocher2DInterpolator, LinearNDInterpolator, NearestNDInterpolator, ) from scipy.sparse import csr_array from scipy.spatial import Delaunay, Voronoi from scipy.spatial.distance import pdist, squareform from .._fiff.meas_info import Info, _simplify_info from .._fiff.pick import ( _MEG_CH_TYPES_SPLIT, _pick_data_channels, _picks_by_type, _picks_to_idx, pick_channels, pick_info, pick_types, ) from ..baseline import rescale from ..defaults import ( _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT, _handle_default, ) from ..transforms import apply_trans, invert_transform from ..utils import ( _check_option, _check_sphere, _clean_names, _is_numeric, _time_mask, _validate_type, check_version, fill_doc, legacy, logger, verbose, warn, ) from ..utils.spectrum import _split_psd_kwargs from .ui_events import TimeChange, publish, subscribe from .utils import ( DraggableColorbar, _check_delayed_ssp, _check_time_unit, _check_type_projs, _draw_proj_checkbox, _format_units_psd, _get_cmap, _get_plot_ch_type, _prepare_sensor_names, _prepare_trellis, _process_times, _set_3d_axes_equal, _setup_cmap, _setup_vmin_vmax, _validate_if_list_of_axes, figure_nobar, plot_sensors, plt_show, ) _fnirs_types = ("hbo", "hbr", "fnirs_cw_amplitude", "fnirs_od") # 3.8+ uses a single Collection artist rather than .collections # https://github.com/matplotlib/matplotlib/pull/25247 def _cont_collections(cont): return (cont,) if check_version("matplotlib", "3.8") else tuple(cont.collections) def _adjust_meg_sphere(sphere, info, ch_type): sphere = _check_sphere(sphere, info) assert ch_type is not None if ch_type in ("mag", "grad", "planar1", "planar2"): # move sphere X/Y (head 
coords) to device X/Y space if info["dev_head_t"] is not None: head_dev_t = invert_transform(info["dev_head_t"]) sphere[:3] = apply_trans(head_dev_t, sphere[:3]) # Set the sphere Z=0 because all this really affects is flattening. # We could make the head size change as a function of depth in # the helmet like: # # sphere[2] /= -5 # # but let's just assume some orthographic rather than parallel # projection for explicitness / simplicity. sphere[2] = 0.0 clip_origin = (0.0, 0.0) else: clip_origin = sphere[:2].copy() return sphere, clip_origin def _prepare_topomap_plot(inst, ch_type, sphere=None): """Prepare topo plot.""" from ..channels.layout import _find_topomap_coords, _pair_grad_sensors, find_layout info = copy.deepcopy(inst if isinstance(inst, Info) else inst.info) sphere, clip_origin = _adjust_meg_sphere(sphere, info, ch_type) clean_ch_names = _clean_names(info["ch_names"]) for ii, this_ch in enumerate(info["chs"]): this_ch["ch_name"] = clean_ch_names[ii] for comp in info["comps"]: comp["data"]["col_names"] = _clean_names(comp["data"]["col_names"]) info._update_redundant() info["bads"] = _clean_names(info["bads"]) info._check_consistency() # special case for merging grad channels layout = find_layout(info) if ( ch_type == "grad" and layout is not None and ( layout.kind.startswith("Vectorview") or layout.kind.startswith("Neuromag_122") ) ): picks, _ = _pair_grad_sensors(info, layout) pos = _find_topomap_coords(info, picks[::2], sphere=sphere) merge_channels = True elif ch_type in _fnirs_types: # fNIRS data commonly has overlapping channels, so deal with separately picks, pos, merge_channels, overlapping_channels = _average_fnirs_overlaps( info, ch_type, sphere ) else: merge_channels = False if ch_type == "eeg": picks = pick_types(info, meg=False, eeg=True, ref_meg=False, exclude="bads") elif ch_type == "csd": picks = pick_types(info, meg=False, csd=True, ref_meg=False, exclude="bads") elif ch_type == "dbs": picks = pick_types(info, meg=False, dbs=True, 
def _average_fnirs_overlaps(info, ch_type, sphere):
    """Find fNIRS channels that share a location and prepare them for merging.

    Parameters
    ----------
    info : mne.Info
        Measurement info; ``info["chs"]`` and ``info["bads"]`` are read.
    ch_type : str
        fNIRS channel type, passed to ``pick_types`` as ``fnirs=ch_type``.
    sphere : object
        Forwarded unchanged to ``_find_topomap_coords``.

    Returns
    -------
    picks : array
        Channel indices returned by ``pick_types`` for ``ch_type``.
    pos : array
        Topomap coordinates. When overlaps exist, only the first channel of
        each overlapping set (plus all non-overlapping channels) contributes
        a position.
    merge_channels : list | False
        The list of overlapping sets when at least one overlap was found,
        otherwise ``False``.
    overlapping_channels : list of array of str
        One array of channel names per overlapping set. The first name is the
        channel that stays in ``pos`` and stands for the merged set.

    Notes
    -----
    Two channels are treated as overlapping when the distance between their
    ``loc[:3]`` entries is below ``1e-10``.
    """
    from ..channels.layout import _find_topomap_coords

    picks = pick_types(info, meg=False, ref_meg=False, fnirs=ch_type, exclude="bads")
    chs = [info["chs"][i] for i in picks]
    locs3d = np.array([ch["loc"][:3] for ch in chs])
    dist = pdist(locs3d)

    # Store the sets of channels to be merged
    overlapping_channels = list()
    merge_channels = False

    if len(locs3d) > 1 and np.min(dist) < 1e-10:
        # Upper triangle only, so each pair is seen once, from the row of
        # its lower-indexed member
        overlapping_mask = np.triu(squareform(dist < 1e-10))
        # Names that already belong to a set; kept incrementally instead of
        # re-flattening ``overlapping_channels`` on every iteration
        already_overlapped = set()
        # Channels to be excluded from picks, as will be removed after merging
        exclude = list()
        for chan_idx, row in enumerate(overlapping_mask):
            first_name = chs[chan_idx]["ch_name"]
            if not row.any() or first_name in already_overlapped:
                continue
            # Determine the set of channels to be combined. Ensure the
            # first listed channel is the one to be replaced with merge.
            # Build the complete list before converting to an array:
            # inserting into an existing string array casts the new name to
            # that array's (possibly shorter) string dtype and truncates it.
            names = [first_name] + [chs[i]["ch_name"] for i in np.where(row)[0]]
            overlapping_channels.append(np.array(names))
            already_overlapped.update(names)
            exclude.extend(names[1:])

        exclude.extend(info["bads"])
        pos_picks = pick_types(
            info, meg=False, ref_meg=False, fnirs=ch_type, exclude=exclude
        )
        pos = _find_topomap_coords(info, pos_picks, sphere=sphere)
        picks = pick_types(info, meg=False, ref_meg=False, fnirs=ch_type)
        # Overload the merge_channels variable as this is returned to calling
        # function and indicates that merging of data is required
        merge_channels = overlapping_channels
    else:
        # No overlaps: the picks computed at the top are already the ones
        # needed here, so they are reused rather than recomputed
        pos = _find_topomap_coords(info, picks, sphere=sphere)

    return picks, pos, merge_channels, overlapping_channels
params["images"], data.T): Zi = interp.set_values(d)() im.set_data(Zi) if cont is None: continue # must be removed and re-added cont_collections = _cont_collections(cont) for col in cont_collections: col.remove() col = cont_collections[0] lw = col.get_linewidth() visible = col.get_visible() patch_ = col.get_clip_path() color = col.get_edgecolors() cont = ax.contour( interp.Xi, interp.Yi, Zi, params["contours"], colors=color, linewidths=lw ) cont_collections = _cont_collections(cont) for col in cont_collections: col.set_visible(visible) col.set_clip_path(patch_) new_contours.append(cont) params["contours_"] = new_contours params["fig"].canvas.draw() def _add_colorbar( ax, im, cmap, *, title=None, format_=None, kind=None, ch_type=None, ): """Add a colorbar to an axis.""" cbar = ax.figure.colorbar(im, format=format_, shrink=0.6) if cmap is not None and cmap[1]: ax.CB = DraggableColorbar(cbar, im, kind, ch_type) cax = cbar.ax if title is not None: cax.set_title(title, y=1.05, fontsize=10) return cbar, cax def _eliminate_zeros(proj): """Remove grad or mag data if only contains 0s (gh 5641).""" GRAD_ENDING = ("2", "3") MAG_ENDING = "1" proj = copy.deepcopy(proj) proj["data"]["data"] = np.atleast_2d(proj["data"]["data"]) for ending in (GRAD_ENDING, MAG_ENDING): names = proj["data"]["col_names"] idx = [i for i, name in enumerate(names) if name.endswith(ending)] # if all 0, remove the 0s an their labels if not proj["data"]["data"][0][idx].any(): new_col_names = np.delete(np.array(names), idx).tolist() new_data = np.delete(np.array(proj["data"]["data"][0]), idx) proj["data"]["col_names"] = new_col_names proj["data"]["data"] = np.array([new_data]) proj["data"]["ncol"] = len(proj["data"]["col_names"]) return proj @fill_doc def plot_projs_topomap( projs, info, *, sensors=True, show_names=False, contours=6, outlines="head", sphere=None, image_interp=_INTERPOLATION_DEFAULT, extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, size=1, cmap=None, vlim=(None, None), 
cnorm=None, colorbar=False, cbar_fmt="%3.1f", units=None, axes=None, show=True, ): """Plot topographic maps of SSP projections. Parameters ---------- projs : list of Projection The projections. %(info_not_none)s Must be associated with the channels in the projectors. .. versionchanged:: 0.20 The positional argument ``layout`` was replaced by ``info``. %(sensors_topomap)s %(show_names_topomap)s .. versionadded:: 1.2 %(contours_topomap)s %(outlines_topomap)s %(sphere_topomap_auto)s %(image_interp_topomap)s %(extrapolate_topomap)s .. versionadded:: 0.20 .. versionchanged:: 0.21 - The default was changed to ``'local'`` for MEG sensors. - ``'local'`` was changed to use a convex hull mask - ``'head'`` was changed to extrapolate out to the clipping circle. %(border_topomap)s .. versionadded:: 0.20 %(res_topomap)s %(size_topomap)s %(cmap_topomap)s %(vlim_plot_topomap_proj)s %(cnorm)s .. versionadded:: 1.2 %(colorbar_topomap)s %(cbar_fmt_topomap)s .. versionadded:: 1.2 %(units_topomap)s .. versionadded:: 1.2 %(axes_plot_projs_topomap)s %(show)s Returns ------- fig : instance of matplotlib.figure.Figure Figure with a topomap subplot for each projector. Notes ----- .. 
versionadded:: 0.9.0 """ fig = _plot_projs_topomap( projs, info, sensors=sensors, show_names=show_names, contours=contours, outlines=outlines, sphere=sphere, image_interp=image_interp, extrapolate=extrapolate, border=border, res=res, size=size, cmap=cmap, vlim=vlim, cnorm=cnorm, colorbar=colorbar, cbar_fmt=cbar_fmt, units=units, axes=axes, ) with warnings.catch_warnings(record=True): warnings.simplefilter("ignore") plt_show(show) return fig def _plot_projs_topomap( projs, info, sensors=True, show_names=False, contours=6, outlines="head", sphere=None, image_interp=_INTERPOLATION_DEFAULT, extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, size=1, cmap=None, vlim=(None, None), cnorm=None, colorbar=False, cbar_fmt="%3.1f", units=None, axes=None, ): import matplotlib.pyplot as plt from ..channels.layout import _merge_ch_data sphere = _check_sphere(sphere, info) projs = _check_type_projs(projs) _validate_type(info, "info", "info") # Preprocess projs to deal with joint MEG projectors. 
If we duplicate these and # split into mag and grad, they should work as expected info_names = _clean_names(info["ch_names"], remove_whitespace=True) use_projs = list() for proj in projs: proj = _eliminate_zeros(proj) # gh 5641, makes a copy proj["data"]["col_names"] = _clean_names( proj["data"]["col_names"], remove_whitespace=True, ) picks = pick_channels(info_names, proj["data"]["col_names"], ordered=True) proj_types = info.get_channel_types(picks) unique_types = sorted(set(proj_types)) for type_ in unique_types: proj_picks = np.where([proj_type == type_ for proj_type in proj_types])[0] use_projs.append(copy.deepcopy(proj)) use_projs[-1]["data"]["data"] = proj["data"]["data"][:, proj_picks] use_projs[-1]["data"]["col_names"] = [ proj["data"]["col_names"][pick] for pick in proj_picks ] projs = use_projs datas, poss, spheres, outliness, ch_typess = [], [], [], [], [] for proj in projs: # get ch_names, ch_types, data data = proj["data"]["data"].ravel() picks = pick_channels(info_names, proj["data"]["col_names"], ordered=True) use_info = pick_info(info, picks) these_ch_types = use_info.get_channel_types(unique=True) assert len(these_ch_types) == 1 # should be guaranteed above ch_type = these_ch_types[0] ( data_picks, pos, merge_channels, names, _, this_sphere, clip_origin, ) = _prepare_topomap_plot(use_info, ch_type, sphere=sphere) these_outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) data = data[data_picks] if merge_channels: data, _ = _merge_ch_data(data, "grad", []) data = data.ravel() # populate containers datas.append(data) poss.append(pos) spheres.append(this_sphere) outliness.append(these_outlines) ch_typess.append(ch_type) del data, pos, this_sphere, these_outlines, ch_type del sphere # setup axes n_projs = len(projs) if axes is None: fig, axes, ncols, nrows = _prepare_trellis( n_projs, ncols="auto", nrows="auto", size=size, sharex=True, sharey=True ) elif isinstance(axes, plt.Axes): axes = [axes] _validate_if_list_of_axes(axes, n_projs) # 
handle vmin/vmax vlims = [None for _ in range(len(datas))] if vlim == "joint": for _ch_type in set(ch_typess): idx = np.where(np.isin(ch_typess, _ch_type))[0] these_data = np.concatenate(np.array(datas, dtype=object)[idx]) norm = all(these_data >= 0) _vl = _setup_vmin_vmax(these_data, vmin=None, vmax=None, norm=norm) for _idx in idx: vlims[_idx] = _vl # make sure we got a vlim for all projs assert all([vl is not None for vl in vlims]) else: vlims = [vlim] * len(datas) # plot for proj, ax, _data, _pos, _vlim, _sphere, _outlines, _ch_type in zip( projs, axes, datas, poss, vlims, spheres, outliness, ch_typess ): # ch_names names = [info["ch_names"][k] for k in _picks_to_idx(info, _ch_type)] names = _prepare_sensor_names(names, show_names) # title title = proj["desc"] title = "\n".join(title[ii : ii + 22] for ii in range(0, len(title), 22)) ax.set_title(title, fontsize=10) # plot im, _ = plot_topomap( _data, _pos[:, :2], vlim=_vlim, cmap=cmap, sensors=sensors, names=names, res=res, axes=ax, outlines=_outlines, contours=contours, cnorm=cnorm, image_interp=image_interp, show=False, extrapolate=extrapolate, sphere=_sphere, border=border, ch_type=_ch_type, ) if colorbar: _add_colorbar( ax, im, cmap, title=units, format_=cbar_fmt, kind="projs_topomap", ch_type=_ch_type, ) return ax.get_figure() def _make_head_outlines(sphere, pos, outlines, clip_origin): """Check or create outlines for topoplot.""" assert isinstance(sphere, np.ndarray) x, y, _, radius = sphere del sphere if outlines in ("head", None): ll = np.linspace(0, 2 * np.pi, 101) head_x = np.cos(ll) * radius + x head_y = np.sin(ll) * radius + y dx = np.exp(np.arccos(np.deg2rad(12)) * 1j) dx, dy = dx.real, dx.imag nose_x = np.array([-dx, 0, dx]) * radius + x nose_y = np.array([dy, 1.15, dy]) * radius + y ear_x = np.array( [0.497, 0.510, 0.518, 0.5299, 0.5419, 0.54, 0.547, 0.532, 0.510, 0.489] ) * (radius * 2) ear_y = ( np.array( [ 0.0555, 0.0775, 0.0783, 0.0746, 0.0555, -0.0055, -0.0932, -0.1313, -0.1384, -0.1199, ] ) 
* (radius * 2) + y ) if outlines is not None: # Define the outline of the head, ears and nose outlines_dict = dict( head=(head_x, head_y), nose=(nose_x, nose_y), ear_left=(-ear_x + x, ear_y), ear_right=(ear_x + x, ear_y), ) else: outlines_dict = dict() # Make the figure encompass slightly more than all points # We probably want to ensure it always contains our most # extremely positioned channels, so we do: mask_scale = max(1.0, np.linalg.norm(pos, axis=1).max() * 1.01 / radius) outlines_dict["mask_pos"] = (mask_scale * head_x, mask_scale * head_y) clip_radius = radius * mask_scale outlines_dict["clip_radius"] = (clip_radius,) * 2 outlines_dict["clip_origin"] = clip_origin outlines = outlines_dict elif isinstance(outlines, dict): if "mask_pos" not in outlines: raise ValueError("You must specify the coordinates of the image mask.") else: raise ValueError("Invalid value for `outlines`.") return outlines def _draw_outlines(ax, outlines): """Draw the outlines for a topomap.""" from matplotlib import rcParams outlines_ = {k: v for k, v in outlines.items() if k not in ["patch"]} for key, (x_coord, y_coord) in outlines_.items(): if "mask" in key or key in ("clip_radius", "clip_origin"): continue ax.plot( x_coord, y_coord, color=rcParams["axes.edgecolor"], linewidth=1, clip_on=False, ) return outlines_ def _get_extra_points(pos, extrapolate, origin, radii): """Get coordinates of additional interpolation points.""" radii = np.array(radii, float) assert radii.shape == (2,) x, y = origin # auto should be gone by now _check_option("extrapolate", extrapolate, ("head", "box", "local")) # the old method of placement - large box mask_pos = None if extrapolate == "box": extremes = np.array([pos.min(axis=0), pos.max(axis=0)]) diffs = extremes[1] - extremes[0] extremes[0] -= diffs extremes[1] += diffs eidx = np.array( list(itertools.product(*([[0] * (pos.shape[1] - 1) + [1]] * pos.shape[1]))) ) pidx = np.tile(np.arange(pos.shape[1])[np.newaxis], (len(eidx), 1)) outer_pts = 
extremes[eidx, pidx] return outer_pts, mask_pos, Delaunay(np.concatenate((pos, outer_pts))) # check if positions are colinear: diffs = np.diff(pos, axis=0) with np.errstate(divide="ignore"): slopes = diffs[:, 1] / diffs[:, 0] colinear = (slopes == slopes[0]).all() or np.isinf(slopes).all() # compute median inter-electrode distance if colinear or pos.shape[0] < 4: dim = 1 if diffs[:, 1].sum() > diffs[:, 0].sum() else 0 sorting = np.argsort(pos[:, dim]) pos_sorted = pos[sorting, :] diffs = np.diff(pos_sorted, axis=0) distances = np.linalg.norm(diffs, axis=1) distance = np.median(distances) else: tri = Delaunay(pos, incremental=True) idx1, idx2, idx3 = tri.simplices.T distances = np.concatenate( [ np.linalg.norm(pos[i1, :] - pos[i2, :], axis=1) for i1, i2 in zip([idx1, idx2], [idx2, idx3]) ] ) distance = np.median(distances) if extrapolate == "local": if colinear or pos.shape[0] < 4: # special case for colinear points and when there is too # little points for Delaunay (needs at least 3) edge_points = sorting[[0, -1]] line_len = np.diff(pos[edge_points, :], axis=0) unit_vec = line_len / np.linalg.norm(line_len) * distance unit_vec_par = unit_vec[:, ::-1] * [[-1, 1]] edge_pos = pos[edge_points, :] + np.concatenate( [-unit_vec, unit_vec], axis=0 ) new_pos = np.concatenate( [pos + unit_vec_par, pos - unit_vec_par, edge_pos], axis=0 ) if pos.shape[0] == 3: # there may be some new_pos points that are too close # to the original points new_pos_diff = pos[..., np.newaxis] - new_pos.T[np.newaxis, :] new_pos_diff = np.linalg.norm(new_pos_diff, axis=1) good_extra = (new_pos_diff > 0.5 * distance).all(axis=0) new_pos = new_pos[good_extra] tri = Delaunay(np.concatenate([pos, new_pos], axis=0)) return new_pos, new_pos, tri # get the convex hull of data points from triangulation hull_pos = pos[tri.convex_hull] # extend the convex hull limits outwards a bit channels_center = pos.mean(axis=0) radial_dir = hull_pos - channels_center unit_radial_dir = radial_dir / np.linalg.norm( 
radial_dir, axis=-1, keepdims=True ) hull_extended = hull_pos + unit_radial_dir * distance mask_pos = hull_pos + unit_radial_dir * distance * 0.5 hull_diff = np.diff(hull_pos, axis=1)[:, 0] hull_distances = np.linalg.norm(hull_diff, axis=-1) del channels_center # Construct a mask mask_pos = np.unique(mask_pos.reshape(-1, 2), axis=0) mask_center = np.mean(mask_pos, axis=0) mask_pos -= mask_center mask_pos = mask_pos[np.argsort(np.arctan2(mask_pos[:, 1], mask_pos[:, 0]))] mask_pos += mask_center # add points along hull edges so that the distance between points # is around that of average distance between channels add_points = list() eps = np.finfo("float").eps n_times_dist = np.round(0.25 * hull_distances / distance).astype("int") for n in range(2, n_times_dist.max() + 1): mask = n_times_dist == n mult = np.arange(1 / n, 1 - eps, 1 / n)[:, np.newaxis, np.newaxis] steps = hull_diff[mask][np.newaxis, ...] * mult add_points.append( (hull_extended[mask, 0][np.newaxis, ...] + steps).reshape((-1, 2)) ) # remove duplicates from hull_extended hull_extended = np.unique(hull_extended.reshape((-1, 2)), axis=0) new_pos = np.concatenate([hull_extended] + add_points) else: assert extrapolate == "head" # return points on the head circle angle = np.arcsin(min(distance / np.mean(radii), 1)) n_pnts = max(12, int(np.round(2 * np.pi / angle))) points_l = np.linspace(0, 2 * np.pi, n_pnts, endpoint=False) use_radii = radii * 1.1 + distance points_x = np.cos(points_l) * use_radii[0] + x points_y = np.sin(points_l) * use_radii[1] + y new_pos = np.stack([points_x, points_y], axis=1) if colinear or pos.shape[0] == 3: tri = Delaunay(np.concatenate([pos, new_pos], axis=0)) return new_pos, mask_pos, tri tri.add_points(new_pos) return new_pos, mask_pos, tri class _GridData: """Unstructured (x,y) data interpolator. This class allows optimized interpolation by computing parameters for a fixed set of true points, and allowing the values at those points to be set independently. 
""" def __init__(self, pos, image_interp, extrapolate, origin, radii, border): # in principle this works in N dimensions, not just 2 assert pos.ndim == 2 and pos.shape[1] == 2, pos.shape _validate_type(border, ("numeric", str), "border") # check that border, if string, is correct if isinstance(border, str): _check_option("border", border, ("mean",), extra="when a string") # Adding points outside the extremes helps the interpolators outer_pts, mask_pts, tri = _get_extra_points(pos, extrapolate, origin, radii) self.n_extra = outer_pts.shape[0] self.mask_pts = mask_pts self.border = border self.tri = tri self.interp = { "cubic": CloughTocher2DInterpolator, "nearest": NearestNDInterpolator, # used only for anim "linear": LinearNDInterpolator, }[image_interp] def set_values(self, v): """Set the values at interpolation points.""" # Rbf with thin-plate is what we used to use, but it's slower and # looks about the same: # # zi = Rbf(x, y, v, function='multiquadric', smooth=0)(xi, yi) # # Eventually we could also do set_values with this class if we want, # see scipy/interpolate/rbf.py, especially the self.nodes one-liner. 
if isinstance(self.border, str): # we've already checked that border = 'mean' n_points = v.shape[0] v_extra = np.zeros(self.n_extra) indices, indptr = self.tri.vertex_neighbor_vertices rng = range(n_points, n_points + self.n_extra) used = np.zeros(len(rng), bool) for idx, extra_idx in enumerate(rng): ngb = indptr[indices[extra_idx] : indices[extra_idx + 1]] ngb = ngb[ngb < n_points] if len(ngb) > 0: used[idx] = True v_extra[idx] = v[ngb].mean() if not used.all() and used.any(): # Eventually we might want to use the value of the nearest # point or something, but this case should hopefully be # rare so for now just use the average value of all extras v_extra[~used] = np.mean(v_extra[used]) else: v_extra = np.full(self.n_extra, self.border, dtype=float) v = np.concatenate((v, v_extra)) self.interpolator = self.interp(self.tri, v) return self def set_locations(self, Xi, Yi): """Set locations for easier (delayed) calling.""" self.Xi = Xi self.Yi = Yi return self def __call__(self, *args): """Evaluate the interpolator.""" if len(args) == 0: args = [self.Xi, self.Yi] return self.interpolator(*args) def _topomap_plot_sensors(pos_x, pos_y, sensors, ax): """Plot sensors.""" if sensors is True: ax.scatter( pos_x, pos_y, s=0.25, marker="o", edgecolor=["k"] * len(pos_x), facecolor="none", ) else: ax.plot(pos_x, pos_y, sensors) def _get_pos_outlines(info, picks, sphere, to_sphere=True): from ..channels.layout import _find_topomap_coords picks = _picks_to_idx(info, picks, "all", exclude=(), allow_empty=False) ch_type = _get_plot_ch_type(pick_info(_simplify_info(info), picks), None) orig_sphere = sphere sphere, clip_origin = _adjust_meg_sphere(sphere, info, ch_type) logger.debug( "Generating pos outlines with sphere " f"{sphere} from {orig_sphere} for {ch_type}" ) pos = _find_topomap_coords( info, picks, ignore_overlap=True, to_sphere=to_sphere, sphere=sphere ) outlines = _make_head_outlines(sphere, pos, "head", clip_origin) return pos, outlines @fill_doc def plot_topomap( data, 
pos, *, ch_type="eeg", sensors=True, names=None, mask=None, mask_params=None, contours=6, outlines="head", sphere=None, image_interp=_INTERPOLATION_DEFAULT, extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, size=1, cmap=None, vlim=(None, None), cnorm=None, axes=None, show=True, onselect=None, ): """Plot a topographic map as image. Parameters ---------- data : array, shape (n_chan,) The data values to plot. %(pos_topomap)s If an :class:`~mne.Info` object it must contain only one channel type and exactly ``len(data)`` channels; the x/y coordinates will be inferred from the montage in the :class:`~mne.Info` object. %(ch_type_topomap)s .. versionadded:: 0.21 %(sensors_topomap)s %(names_topomap)s %(mask_topomap)s %(mask_params_topomap)s %(contours_topomap)s %(outlines_topomap)s %(sphere_topomap_auto)s %(image_interp_topomap)s %(extrapolate_topomap)s .. versionadded:: 0.18 .. versionchanged:: 0.21 - The default was changed to ``'local'`` for MEG sensors. - ``'local'`` was changed to use a convex hull mask - ``'head'`` was changed to extrapolate out to the clipping circle. %(border_topomap)s .. versionadded:: 0.20 %(res_topomap)s %(size_topomap)s %(cmap_topomap)s %(vlim_plot_topomap)s .. versionadded:: 1.2 %(cnorm)s .. versionadded:: 0.24 %(axes_plot_topomap)s .. versionchanged:: 1.2 If ``axes=None``, a new :class:`~matplotlib.figure.Figure` is created instead of plotting into the current axes. %(show)s onselect : callable | None A function to be called when the user selects a set of channels by click-dragging (uses a matplotlib :class:`~matplotlib.widgets.RectangleSelector`). If ``None`` interactive channel selection is disabled. Defaults to ``None``. Returns ------- im : matplotlib.image.AxesImage The interpolated data. cn : matplotlib.contour.ContourSet The fieldlines. 
""" import matplotlib.pyplot as plt from matplotlib.colors import Normalize if axes is None: _, axes = plt.subplots(figsize=(size, size), layout="constrained") sphere = _check_sphere(sphere, pos if isinstance(pos, Info) else None) _validate_type(cnorm, (Normalize, None), "cnorm") if cnorm is not None and (vlim[0] is not None or vlim[1] is not None): warn( f"Provided cnorm implicitly defines vmin={cnorm.vmin} and " f"vmax={cnorm.vmax}; ignoring additional vlim/vmin/vmax params." ) return _plot_topomap( data, pos, vmin=vlim[0], vmax=vlim[1], cmap=cmap, sensors=sensors, res=res, axes=axes, names=names, mask=mask, mask_params=mask_params, outlines=outlines, contours=contours, image_interp=image_interp, show=show, onselect=onselect, extrapolate=extrapolate, sphere=sphere, border=border, ch_type=ch_type, cnorm=cnorm, )[:2] def _setup_interp(pos, res, image_interp, extrapolate, outlines, border): if image_interp not in ("cubic", "linear", "nearest"): raise RuntimeError( "`image_interp` must be `cubic`, `linear` or `nearest`, got " f"{image_interp}. Previously, `image_interp` controlled " "the matplotlib image interpolation after an original cubic " "interpolation but this was changed to control the first " "interpolation instead for simplicity. 
To change the " "matplotlib image interpolation, use " "`im.set_interpolation(...)" ) logger.debug( f"Interpolation mode {image_interp}, " f"extrapolation mode {extrapolate} to {border}" ) xlim = ( np.inf, -np.inf, ) ylim = ( np.inf, -np.inf, ) mask_ = np.c_[outlines["mask_pos"]] clip_radius = outlines["clip_radius"] clip_origin = outlines.get("clip_origin", (0.0, 0.0)) xmin, xmax = ( np.min(np.r_[xlim[0], mask_[:, 0], clip_origin[0] - clip_radius[0]]), np.max(np.r_[xlim[1], mask_[:, 0], clip_origin[0] + clip_radius[0]]), ) ymin, ymax = ( np.min(np.r_[ylim[0], mask_[:, 1], clip_origin[1] - clip_radius[1]]), np.max(np.r_[ylim[1], mask_[:, 1], clip_origin[1] + clip_radius[1]]), ) xi = np.linspace(xmin, xmax, res) yi = np.linspace(ymin, ymax, res) Xi, Yi = np.meshgrid(xi, yi) interp = _GridData(pos, image_interp, extrapolate, clip_origin, clip_radius, border) extent = (xmin, xmax, ymin, ymax) return extent, Xi, Yi, interp _VORONOI_CIRCLE_RES = 100 def _voronoi_topomap(data, pos, outlines, ax, cmap, norm, extent, res): """Make a Voronoi diagram on a topomap.""" # we need an image axis object so first empty image to plot over im = ax.imshow( np.zeros((res, res)) * np.nan, cmap=cmap, origin="lower", aspect="equal", extent=extent, norm=norm, ) rx, ry = outlines["clip_radius"] cx, cy = outlines.get("clip_origin", (0.0, 0.0)) # add points on the circle to make boundaries, expand out to # ensure regions extend to the edge of the topomap vor = Voronoi( np.concatenate( [ pos, [ ( rx * 1.5 * np.cos(2 * np.pi / _VORONOI_CIRCLE_RES * t), ry * 1.5 * np.sin(2 * np.pi / _VORONOI_CIRCLE_RES * t), ) for t in range(_VORONOI_CIRCLE_RES) ], ] ) ) for point_idx, region_idx in enumerate(vor.point_region[:-_VORONOI_CIRCLE_RES]): if -1 in vor.regions[region_idx]: continue polygon = list() for i in vor.regions[region_idx]: x, y = vor.vertices[i] if (x - cx) ** 2 / rx**2 + (y - cy) ** 2 / ry**2 < 1: polygon.append((x, y)) else: x *= rx / np.linalg.norm(vor.vertices[i]) y *= ry / 
def _get_patch(outlines, extrapolate, interp, ax):
    """Return the patch used to clip the topomap image and contours.

    Parameters
    ----------
    outlines : dict
        Outlines dictionary. ``"clip_radius"`` is required; ``"clip_origin"``
        defaults to ``(0.0, 0.0)``. An optional ``"patch"`` entry may be a
        patch or a callable returning one.
    extrapolate : str
        When ``"local"`` (and default head outlines are in use) the patch is
        a polygon through ``interp.mask_pts``; otherwise it is an ellipse.
    interp : object
        Only ``interp.mask_pts`` is read, and only for ``"local"``.
    ax : matplotlib Axes
        Axes that receive a user-supplied patch and provide ``transData``.

    Returns
    -------
    patch_ : matplotlib patch | None
        The clipping patch, or None when ``outlines`` has neither a
        ``"patch"`` entry nor any key starting with ``"head"``.
    """
    from matplotlib import patches

    clip_radius = outlines["clip_radius"]
    clip_origin = outlines.get("clip_origin", (0.0, 0.0))
    has_head_outline = any(key.startswith("head") for key in outlines)

    patch_ = None
    if "patch" in outlines:
        patch_ = outlines["patch"]
        if callable(patch_):
            patch_ = patch_()
        patch_.set_clip_on(False)
        ax.add_patch(patch_)
        ax.set_transform(ax.transAxes)
        ax.set_clip_path(patch_)

    # Without default head outlines, whatever the user supplied (or None) wins
    if not has_head_outline:
        return patch_

    if extrapolate == "local":
        return patches.Polygon(interp.mask_pts, clip_on=True, transform=ax.transData)

    width, height = 2 * clip_radius[0], 2 * clip_radius[1]
    return patches.Ellipse(
        clip_origin, width, height, clip_on=True, transform=ax.transData
    )
" + info_help) elif len(pos["chs"]) != data.shape[0]: raise ValueError( f"Number of channels in the Info object ({len(pos['chs'])}) and the " f"data array ({data.shape[0]}) do not match." + info_help ) else: ch_type = ch_type.pop() if any(type_ in ch_type for type_ in ("planar", "grad")): # deal with grad pairs picks = _pair_grad_sensors(pos, topomap_coords=False) pos = _find_topomap_coords(pos, picks=picks[::2], sphere=sphere) data, _ = _merge_ch_data(data[picks], ch_type, []) data = data.reshape(-1) else: picks = list(range(data.shape[0])) pos = _find_topomap_coords(pos, picks=picks, sphere=sphere) extrapolate = _check_extrapolate(extrapolate, ch_type) if data.ndim > 1: raise ValueError( f"Data needs to be array of shape (n_sensors,); got shape {data.shape}." ) # Give a helpful error message for common mistakes regarding the position # matrix. pos_help = ( "Electrode positions should be specified as a 2D array with " "shape (n_channels, 2). Each row in this matrix contains the " "(x, y) position of an electrode." ) if pos.ndim != 2: error = ( f"{pos.ndim}D array supplied as electrode positions, where a 2D array was " "expected" ) raise ValueError(error + " " + pos_help) elif pos.shape[1] == 3: error = ( "The supplied electrode positions matrix contains 3 columns. " "Are you trying to specify XYZ coordinates? Perhaps the " "mne.channels.create_eeg_layout function is useful for you." ) raise ValueError(error + " " + pos_help) # No error is raised in case of pos.shape[1] == 4. In this case, it is # assumed the position matrix contains both (x, y) and (width, height) # values, such as Layout.pos. elif pos.shape[1] == 1 or pos.shape[1] > 4: raise ValueError(pos_help) pos = pos[:, :2] if len(data) != len(pos): raise ValueError( "Data and pos need to be of same length. 
Got data of " f"length {len(data)}, pos of length { len(pos)}" ) norm = min(data) >= 0 vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm) if cmap is None: cmap = "Reds" if norm else "RdBu_r" cmap = _get_cmap(cmap) outlines = _make_head_outlines(sphere, pos, outlines, (0.0, 0.0)) assert isinstance(outlines, dict) _prepare_topomap(pos, axes) mask_params = _handle_default("mask_params", mask_params) # find mask limits and setup interpolation extent, Xi, Yi, interp = _setup_interp( pos, res, image_interp, extrapolate, outlines, border ) interp.set_values(data) Zi = interp.set_locations(Xi, Yi)() # plot outline patch_ = _get_patch(outlines, extrapolate, interp, axes) # get colormap normalization if cnorm is None: cnorm = Normalize(vmin=vmin, vmax=vmax) # plot interpolated map if image_interp == "nearest": # plot over with Voronoi, more accurate im = _voronoi_topomap( data, pos=pos, outlines=outlines, ax=axes, cmap=cmap, norm=cnorm, extent=extent, res=res, ) else: im = axes.imshow( Zi, cmap=cmap, origin="lower", aspect="equal", extent=extent, interpolation="bilinear", norm=cnorm, ) # gh-1432 had a workaround for no contours here, but we'll remove it # because mpl has probably fixed it linewidth = mask_params["markeredgewidth"] cont = True if isinstance(contours, (np.ndarray, list)): pass elif contours == 0 or ((Zi == Zi[0, 0]) | np.isnan(Zi)).all(): cont = None # can't make contours for constant-valued functions if cont: with warnings.catch_warnings(record=True): warnings.simplefilter("ignore") cont = axes.contour( Xi, Yi, Zi, contours, colors="k", linewidths=linewidth / 2.0 ) if patch_ is not None: im.set_clip_path(patch_) if cont is not None: for col in _cont_collections(cont): col.set_clip_path(patch_) pos_x, pos_y = pos.T mask = mask.astype(bool, copy=False) if mask is not None else None if sensors is not False and mask is None: _topomap_plot_sensors(pos_x, pos_y, sensors=sensors, ax=axes) elif sensors and mask is not None: idx = np.where(mask)[0] 
def _plot_ica_topomap(
    ica,
    idx=0,
    ch_type=None,
    res=64,
    vmin=None,
    vmax=None,
    cmap="RdBu_r",
    colorbar=False,
    title=None,
    show=True,
    outlines="head",
    contours=6,
    image_interp=_INTERPOLATION_DEFAULT,
    axes=None,
    sensors=True,
    allow_ref_meg=False,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    sphere=None,
    border=_BORDER_DEFAULT,
):
    """Draw the topography of a single ICA component into existing axes.

    The component ``idx`` is taken from ``ica.get_components()``, restricted
    to the channels of ``ch_type`` and handed to ``plot_topomap``. The axes
    title is the component name, followed by the channel type when the ICA
    contains more than one channel type.

    ``axes`` must be a matplotlib ``Axes`` instance, otherwise a ValueError
    is raised; a RuntimeError is raised when ``ica.info`` is None. Nothing is
    drawn (and None is returned) for ``ch_type == "ref_meg"``.

    NOTE(review): ``title`` and ``allow_ref_meg`` are accepted but not read
    in this function (``ica.allow_ref_meg`` is used instead) -- confirm this
    is intended.
    """
    from matplotlib.axes import Axes

    from ..channels.layout import _merge_ch_data

    if ica.info is None:
        raise RuntimeError(
            "The ICA's measurement info is missing. Please "
            "fit the ICA or add the corresponding info object."
        )
    sphere = _check_sphere(sphere, ica.info)
    if not isinstance(axes, Axes):
        raise ValueError(
            "axis has to be an instance of matplotlib Axes, "
            f"got {type(axes)} instead."
        )
    ch_type = _get_plot_ch_type(ica, ch_type, allow_ref_meg=ica.allow_ref_meg)
    if ch_type == "ref_meg":
        logger.info("Cannot produce topographies for MEG reference channels.")
        return

    # pull out the requested component and restrict it to the plotted channels
    component = ica.get_components()[:, idx]
    prepared = _prepare_topomap_plot(ica, ch_type, sphere=sphere)
    data_picks, pos, merge_channels, names, _, sphere, clip_origin = prepared
    component = component[data_picks]
    outlines = _make_head_outlines(sphere, pos, outlines, clip_origin)
    if merge_channels:
        component, names = _merge_ch_data(component, ch_type, names)

    # title: component name, plus the channel type if there are several types
    topo_title = ica._ica_names[idx]
    suffix = f" ({ch_type})" if len(set(ica.get_channel_types())) > 1 else ""
    axes.set_title(topo_title + suffix, fontsize=12)

    vlim = _setup_vmin_vmax(component, vmin, vmax)
    topomap_kwargs = dict(
        vlim=vlim,
        res=res,
        axes=axes,
        cmap=cmap,
        outlines=outlines,
        contours=contours,
        sensors=sensors,
        image_interp=image_interp,
        show=show,
        extrapolate=extrapolate,
        sphere=sphere,
        border=border,
        ch_type=ch_type,
    )
    im = plot_topomap(component.ravel(), pos, **topomap_kwargs)[0]

    if colorbar:
        cbar, _ = _add_colorbar(
            axes,
            im,
            cmap,
            title="AU",
            format_="%3.2f",
            kind="ica_topomap",
            ch_type=ch_type,
        )
        cbar.ax.tick_params(labelsize=12)
        cbar.set_ticks(vlim)
    _hide_frame(axes)
%(picks_ica)s %(ch_type_topomap)s inst : Raw | Epochs | None To be able to see component properties after clicking on component topomap you need to pass relevant data - instances of Raw or Epochs (for example the data that ICA was trained on). This takes effect only when running matplotlib in interactive mode. plot_std : bool | float Whether to plot standard deviation in ERP/ERF and spectrum plots. Defaults to True, which plots one standard deviation above/below. If set to float allows to control how many standard deviations are plotted. For example 2.5 will plot 2.5 standard deviation above/below. reject : ``'auto'`` | dict | None Allows to specify rejection parameters used to drop epochs (or segments if continuous signal is passed as inst). If None, no rejection is applied. The default is 'auto', which applies the rejection parameters used when fitting the ICA object. %(sensors_topomap)s %(show_names_topomap)s %(contours_topomap)s %(outlines_topomap)s %(sphere_topomap_auto)s %(image_interp_topomap)s %(extrapolate_topomap)s .. versionadded:: 1.3 %(border_topomap)s .. versionadded:: 1.3 %(res_topomap)s %(size_topomap)s .. versionadded:: 1.3 %(cmap_topomap)s %(vlim_plot_topomap)s .. versionadded:: 1.3 %(cnorm)s .. versionadded:: 1.3 %(colorbar_topomap)s %(cbar_fmt_topomap)s axes : Axes | array of Axes | None The subplot(s) to plot to. Either a single Axes or an iterable of Axes if more than one subplot is needed. The number of subplots must match the number of selected components. If None, new figures will be created with the number of subplots per figure controlled by ``nrows`` and ``ncols``. title : str | None The title of the generated figure. If ``None`` (default) and ``axes=None``, a default title of "ICA Components" will be used. %(nrows_ncols_ica_components)s .. versionadded:: 1.3 %(show)s image_args : dict | None Dictionary of arguments to pass to :func:`~mne.viz.plot_epochs_image` in interactive mode. Ignored if ``inst`` is not supplied. 
If ``None``, nothing is passed. Defaults to ``None``. psd_args : dict | None Dictionary of arguments to pass to :meth:`~mne.Epochs.compute_psd` in interactive mode. Ignored if ``inst`` is not supplied. If ``None``, nothing is passed. Defaults to ``None``. %(verbose)s Returns ------- fig : instance of matplotlib.figure.Figure | list of matplotlib.figure.Figure The figure object(s). Notes ----- When run in interactive mode, ``plot_ica_components`` allows to reject components by clicking on their title label. The state of each component is indicated by its label color (gray: rejected; black: retained). It is also possible to open component properties by clicking on the component topomap (this option is only available when the ``inst`` argument is supplied). """ # noqa E501 from matplotlib.pyplot import Axes from ..channels.layout import _merge_ch_data from ..epochs import BaseEpochs from ..io import BaseRaw if ica.info is None: raise RuntimeError( "The ICA's measurement info is missing. Please " "fit the ICA or add the corresponding info object." 
) # for backward compat, nrow='auto' ncol='auto' should yield 4 rows 5 cols # and create multiple figures if more than 20 components requested if nrows == "auto" and ncols == "auto": ncols = 5 max_subplots = 20 elif nrows == "auto" or ncols == "auto": # user provided incomplete row/col spec; put all in one figure max_subplots = ica.n_components_ else: max_subplots = nrows * ncols # handle ch_type=None ch_type = _get_plot_ch_type(ica, ch_type) figs = [] if picks is None: cut_points = range(max_subplots, ica.n_components_, max_subplots) pick_groups = np.split(range(ica.n_components_), cut_points) else: pick_groups = [_picks_to_idx(ica.n_components_, picks, picks_on="components")] axes = axes.flatten() if isinstance(axes, np.ndarray) else axes for k, picks in enumerate(pick_groups): try: # either an iterable, 1D numpy array or others _axes = axes[k * max_subplots : (k + 1) * max_subplots] except TypeError: # None or Axes _axes = axes ( data_picks, pos, merge_channels, names, ch_type, sphere, clip_origin, ) = _prepare_topomap_plot(ica, ch_type, sphere=sphere) cmap = _setup_cmap(cmap, n_axes=len(picks)) names = _prepare_sensor_names(names, show_names) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) data = np.dot( ica.mixing_matrix_[:, picks].T, ica.pca_components_[: ica.n_components_] ) data = np.atleast_2d(data) data = data[:, data_picks] if title is None: title = "ICA components" user_passed_axes = _axes is not None if not user_passed_axes: fig, _axes, _, _ = _prepare_trellis(len(data), ncols=ncols, nrows=nrows) fig.suptitle(title) else: _axes = [_axes] if isinstance(_axes, Axes) else _axes fig = _axes[0].get_figure() subplot_titles = list() for ii, data_, ax in zip(picks, data, _axes): kwargs = dict(color="gray") if ii in ica.exclude else dict() comp_title = ica._ica_names[ii] if len(set(ica.get_channel_types())) > 1: comp_title += f" ({ch_type})" subplot_titles.append(ax.set_title(comp_title, fontsize=12, **kwargs)) if merge_channels: data_, 
names_ = _merge_ch_data(data_, ch_type, copy.copy(names)) # ↓↓↓ NOTE: we intentionally use the default norm=False here, so that # ↓↓↓ we get vlims that are symmetric-about-zero, even if the data for # ↓↓↓ a given component happens to be one-sided. _vlim = _setup_vmin_vmax(data_, *vlim) im = plot_topomap( data_.flatten(), pos, ch_type=ch_type, sensors=sensors, names=names, contours=contours, outlines=outlines, sphere=sphere, image_interp=image_interp, extrapolate=extrapolate, border=border, res=res, size=size, cmap=cmap[0], vlim=_vlim, cnorm=cnorm, axes=ax, show=False, )[0] im.axes.set_label(ica._ica_names[ii]) if colorbar: cbar, cax = _add_colorbar( ax, im, cmap, title="AU", format_=cbar_fmt, kind="ica_comp_topomap", ch_type=ch_type, ) cbar.ax.tick_params(labelsize=12) cbar.set_ticks(_vlim) _hide_frame(ax) del pos fig.canvas.draw() # add title selection interactivity def onclick_title(event, ica=ica, titles=subplot_titles, fig=fig): # check if any title was pressed title_pressed = None for title in titles: if title.contains(event)[0]: title_pressed = title break # title was pressed -> identify the IC if title_pressed is not None: label = title_pressed.get_text() ic = int(label.split(" ")[0][-3:]) # add or remove IC from exclude depending on current state if ic in ica.exclude: ica.exclude.remove(ic) title_pressed.set_color("k") else: ica.exclude.append(ic) title_pressed.set_color("gray") fig.canvas.draw() fig.canvas.mpl_connect("button_press_event", onclick_title) # add plot_properties interactivity only if inst was passed if isinstance(inst, (BaseRaw, BaseEpochs)): topomap_args = dict( sensors=sensors, contours=contours, outlines=outlines, sphere=sphere, image_interp=image_interp, extrapolate=extrapolate, border=border, res=res, cmap=cmap[0], vmin=vlim[0], vmax=vlim[1], ) def onclick_topo(event, ica=ica, inst=inst): # check which component to plot if event.inaxes is not None: label = event.inaxes.get_label() if label.startswith("ICA"): ic = int(label.split(" 
")[0][-3:]) ica.plot_properties( inst, picks=ic, show=True, plot_std=plot_std, topomap_args=topomap_args, image_args=image_args, psd_args=psd_args, reject=reject, ) fig.canvas.mpl_connect("button_press_event", onclick_topo) figs.append(fig) plt_show(show) return figs[0] if len(figs) == 1 else figs @fill_doc def plot_tfr_topomap( tfr, tmin=None, tmax=None, fmin=0.0, fmax=np.inf, *, ch_type=None, baseline=None, mode="mean", sensors=True, show_names=False, mask=None, mask_params=None, contours=6, outlines="head", sphere=None, image_interp=_INTERPOLATION_DEFAULT, extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, size=2, cmap=None, vlim=(None, None), cnorm=None, colorbar=True, cbar_fmt="%1.1e", units=None, axes=None, show=True, ): """Plot topographic maps of specific time-frequency intervals of TFR data. Parameters ---------- tfr : AverageTFR The AverageTFR object. %(tmin_tmax_psd)s %(fmin_fmax_psd)s %(ch_type_topomap_psd)s baseline : tuple or list of length 2 The time interval to apply rescaling / baseline correction. If None do not apply it. If baseline is (a, b) the interval is between "a (s)" and "b (s)". If a is None the beginning of the data is used and if b is None then b is set to the end of the interval. If baseline is equal to (None, None) the whole time interval is used. mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio' | None Perform baseline correction by - subtracting the mean baseline power ('mean') - dividing by the mean baseline power ('ratio') - dividing by the mean baseline power and taking the log ('logratio') - subtracting the mean baseline power followed by dividing by the mean baseline power ('percent') - subtracting the mean baseline power and dividing by the standard deviation of the baseline power ('zscore') - dividing by the mean baseline power, taking the log, and dividing by the standard deviation of the baseline power ('zlogratio') If None no baseline correction is applied. 
%(sensors_topomap)s %(show_names_topomap)s %(mask_evoked_topomap)s %(mask_params_topomap)s %(contours_topomap)s %(outlines_topomap)s %(sphere_topomap_auto)s %(image_interp_topomap)s %(extrapolate_topomap)s .. versionchanged:: 0.21 - The default was changed to ``'local'`` for MEG sensors. - ``'local'`` was changed to use a convex hull mask - ``'head'`` was changed to extrapolate out to the clipping circle. %(border_topomap)s .. versionadded:: 0.20 %(res_topomap)s %(size_topomap)s %(cmap_topomap)s %(vlim_plot_topomap)s .. versionadded:: 1.2 %(cnorm)s .. versionadded:: 1.2 %(colorbar_topomap)s %(cbar_fmt_topomap)s %(units_topomap)s %(axes_plot_topomap)s %(show)s Returns ------- fig : matplotlib.figure.Figure The figure containing the topography. """ # noqa: E501 import matplotlib.pyplot as plt from ..channels.layout import _merge_ch_data ch_type = _get_plot_ch_type(tfr, ch_type) picks, pos, merge_channels, names, _, sphere, clip_origin = _prepare_topomap_plot( tfr, ch_type, sphere=sphere ) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) data = tfr.data[picks, :, :] # merging grads before rescaling makes ERDs visible if merge_channels: data, names = _merge_ch_data(data, ch_type, names, method="mean") data = rescale(data, tfr.times, baseline, mode, copy=True) if np.iscomplexobj(data): data = np.sqrt((data * data.conj()).real) # crop time itmin, itmax = None, None idx = np.where(_time_mask(tfr.times, tmin, tmax))[0] if tmin is not None: itmin = idx[0] if tmax is not None: itmax = idx[-1] + 1 # crop freqs ifmin, ifmax = None, None idx = np.where(_time_mask(tfr.freqs, fmin, fmax))[0] ifmin = idx[0] ifmax = idx[-1] + 1 data = data[:, ifmin:ifmax, itmin:itmax] data = data.mean(axis=(1, 2))[:, np.newaxis] norm = False if np.min(data) < 0 else True vlim = _setup_vmin_vmax(data, *vlim, norm) cmap = _setup_cmap(cmap, norm=norm) axes = ( plt.subplots(figsize=(size, size), layout="constrained")[1] if axes is None else axes ) fig = axes.figure _hide_frame(axes) 
locator = None if not isinstance(contours, (list, np.ndarray)): locator, contours = _set_contour_locator(*vlim, contours) fig_wrapper = list() selection_callback = partial( _onselect, tfr=tfr, pos=pos, ch_type=ch_type, itmin=itmin, itmax=itmax, ifmin=ifmin, ifmax=ifmax, cmap=cmap[0], fig=fig_wrapper, ) if not isinstance(contours, (list, np.ndarray)): _, contours = _set_contour_locator(*vlim, contours) names = _prepare_sensor_names(names, show_names) im, _ = plot_topomap( data[:, 0], pos, ch_type=ch_type, sensors=sensors, names=names, mask=mask, mask_params=mask_params, contours=contours, outlines=outlines, sphere=sphere, image_interp=image_interp, extrapolate=extrapolate, border=border, res=res, size=size, cmap=cmap[0], vlim=vlim, cnorm=cnorm, axes=axes, show=False, onselect=selection_callback, ) if colorbar: from matplotlib import ticker units = _handle_default("units", units)["misc"] cbar, cax = _add_colorbar( axes, im, cmap, title=units, format_=cbar_fmt, kind="tfr_topomap", ch_type=ch_type, ) if locator is None: locator = ticker.MaxNLocator(nbins=5) cbar.locator = locator cbar.update_ticks() cbar.ax.tick_params(labelsize=12) plt_show(show) return fig @fill_doc def plot_evoked_topomap( evoked, times="auto", *, average=None, ch_type=None, scalings=None, proj=False, sensors=True, show_names=False, mask=None, mask_params=None, contours=6, outlines="head", sphere=None, image_interp=_INTERPOLATION_DEFAULT, extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, size=1, cmap=None, vlim=(None, None), cnorm=None, colorbar=True, cbar_fmt="%3.1f", units=None, axes=None, time_unit="s", time_format=None, nrows=1, ncols="auto", show=True, ): """Plot topographic maps of specific time points of evoked data. Parameters ---------- evoked : Evoked The Evoked object. times : float | array of float | "auto" | "peaks" | "interactive" The time point(s) to plot. If "auto", the number of ``axes`` determines the amount of time point(s). 
If ``axes`` is also None, at most 10 topographies will be shown with a regular time spacing between the first and last time instant. If "peaks", finds time points automatically by checking for local maxima in global field power. If "interactive", the time can be set interactively at run-time by using a slider. %(average_plot_evoked_topomap)s %(ch_type_topomap)s %(scalings_topomap)s %(proj_plot)s %(sensors_topomap)s %(show_names_topomap)s %(mask_evoked_topomap)s %(mask_params_topomap)s %(contours_topomap)s %(outlines_topomap)s %(sphere_topomap_auto)s %(image_interp_topomap)s %(extrapolate_topomap)s .. versionadded:: 0.18 .. versionchanged:: 0.21 - The default was changed to ``'local'`` for MEG sensors. - ``'local'`` was changed to use a convex hull mask - ``'head'`` was changed to extrapolate out to the clipping circle. %(border_topomap)s .. versionadded:: 0.20 %(res_topomap)s %(size_topomap)s %(cmap_topomap)s %(vlim_plot_topomap_psd)s .. versionadded:: 1.2 %(cnorm)s .. versionadded:: 1.2 %(colorbar_topomap)s %(cbar_fmt_topomap)s %(units_topomap_evoked)s %(axes_evoked_plot_topomap)s time_unit : str The units for the time axis, can be "ms" or "s" (default). .. versionadded:: 0.16 time_format : str | None String format for topomap values. Defaults (None) to "%%01d ms" if ``time_unit='ms'``, "%%0.3f s" if ``time_unit='s'``, and "%%g" otherwise. Can be an empty string to omit the time label. %(nrows_ncols_topomap)s Ignored when times == 'interactive'. .. versionadded:: 0.20 %(show)s Returns ------- fig : instance of matplotlib.figure.Figure The figure. Notes ----- When existing ``axes`` are provided and ``colorbar=True``, note that the colorbar scale will only accurately reflect topomaps that are generated in the same call as the colorbar. Note also that the colorbar will not be resized automatically when ``axes`` are provided; use Matplotlib's :meth:`axes.set_position() ` method or :ref:`gridspec ` interface to adjust the colorbar size yourself. 
When ``time=="interactive"``, the figure will publish and subscribe to the following UI events: * :class:`~mne.viz.ui_events.TimeChange` whenever a new time is selected. """ import matplotlib.pyplot as plt from matplotlib.gridspec import GridSpec from matplotlib.widgets import Slider from ..channels.layout import _merge_ch_data from ..evoked import Evoked _validate_type(evoked, Evoked, "evoked") _validate_type(colorbar, bool, "colorbar") evoked = evoked.copy() # make a copy, since we'll be picking ch_type = _get_plot_ch_type(evoked, ch_type) # time units / formatting time_unit, _ = _check_time_unit(time_unit, evoked.times) scaling_time = 1.0 if time_unit == "s" else 1e3 _validate_type(time_format, (None, str), "time_format") if time_format is None: time_format = "%0.3f s" if time_unit == "s" else "%01d ms" del time_unit # mask_params defaults mask_params = _handle_default("mask_params", mask_params) mask_params["markersize"] *= size / 2.0 mask_params["markeredgewidth"] *= size / 2.0 # setup various parameters, and prepare outlines ( picks, pos, merge_channels, names, ch_type, sphere, clip_origin, ) = _prepare_topomap_plot(evoked, ch_type, sphere=sphere) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) # check interactive axes_given = axes is not None interactive = isinstance(times, str) and times == "interactive" if interactive and axes_given: raise ValueError("User-provided axes not allowed when times='interactive'.") # units, scalings key = "grad" if ch_type.startswith("planar") else ch_type default_scaling = _handle_default("scalings", None)[key] scaling = _handle_default("scalings", scalings)[key] # if non-default scaling, fall back to "AU" if unit wasn't given by user key = "misc" if scaling != default_scaling else key unit = _handle_default("units", units)[key] # ch_names (required for NIRS) ch_names = names names = _prepare_sensor_names(names, show_names) # apply projections before picking. 
NOTE: the `if proj is True` # anti-pattern is needed here to exclude proj='interactive' _check_option("proj", proj, (True, False, "interactive", "reconstruct")) if proj is True and not evoked.proj: evoked.apply_proj() elif proj == "reconstruct": evoked._reconstruct_proj() data = evoked.data # remove compensation matrices (safe: only plotting & already made copy) with evoked.info._unlock(): evoked.info["comps"] = [] evoked = evoked._pick_drop_channels(picks, verbose=False) # determine which times to plot if isinstance(axes, plt.Axes): axes = [axes] n_peaks = len(axes) - int(colorbar) if axes_given else None times = _process_times(evoked, times, n_peaks) n_times = len(times) space = 1 / (2.0 * evoked.info["sfreq"]) if max(times) > max(evoked.times) + space or min(times) < min(evoked.times) - space: raise ValueError( f"Times should be between {evoked.times[0]:0.3} and " f"{evoked.times[-1]:0.3}." ) # create axes want_axes = n_times + int(colorbar) if interactive: height_ratios = [5, 1] nrows = 2 ncols = n_times width = size * want_axes height = size + max(0, 0.1 * (4 - size)) fig = figure_nobar(figsize=(width * 1.5, height * 1.5)) gs = GridSpec(nrows, ncols, height_ratios=height_ratios, figure=fig) axes = [] for ax_idx in range(n_times): axes.append(plt.subplot(gs[0, ax_idx])) elif axes is None: fig, axes, ncols, nrows = _prepare_trellis( n_times, ncols=ncols, nrows=nrows, size=size ) else: nrows, ncols = None, None # Deactivate ncols when axes were passed fig = axes[0].get_figure() # check: enough space for colorbar? if len(axes) != want_axes: cbar_err = " plus one for the colorbar" if colorbar else "" raise RuntimeError( f"You must provide {want_axes} axes (one for " f"each time{cbar_err}), got {len(axes)}." 
) del want_axes # find first index that's >= (to rounding error) to each time point time_idx = [ np.where( _time_mask(evoked.times, tmin=t, tmax=None, sfreq=evoked.info["sfreq"]) )[0][0] for t in times ] # do averaging if requested avg_err = ( '"average" must be `None`, a positive number of seconds, or ' "an array-like object of the previous" ) averaged_times = [] if average is None: average = np.array([None] * n_times) data = data[np.ix_(picks, time_idx)] else: if _is_numeric(average): average = np.array([average] * n_times) elif np.array(average).ndim == 0: # It should be an array-like object raise TypeError(f"{avg_err}; got type: {type(average)}.") else: average = np.array(average) if len(average) != n_times: raise ValueError( f"You requested to plot topographic maps for {n_times} time " f"points, but provided {len(average)} periods for " f"averaging. The number of time points and averaging periods " f"must be equal." ) data_ = np.zeros((len(picks), len(time_idx))) for average_idx, (this_average, this_time, this_time_idx) in enumerate( zip(average, evoked.times[time_idx], time_idx) ): if (_is_numeric(this_average) and this_average <= 0) or ( not _is_numeric(this_average) and this_average is not None ): if len(average) == 1: msg = f"{avg_err}; got {this_average}" else: msg = f"{avg_err}; got {this_average} in {average}" raise ValueError(msg) if this_average is None: data_[:, average_idx] = data[picks][:, this_time_idx] averaged_times.append([this_time]) else: tmin_ = this_time - this_average / 2 tmax_ = this_time + this_average / 2 time_mask = (tmin_ < evoked.times) & (evoked.times < tmax_) data_[:, average_idx] = data[picks][:, time_mask].mean(-1) averaged_times.append(evoked.times[time_mask]) data = data_ # apply scalings and merge channels data *= scaling if merge_channels: data, ch_names = _merge_ch_data(data, ch_type, ch_names) if ch_type in _fnirs_types: merge_channels = False # apply mask if requested if mask is not None: mask = mask.astype(bool, 
copy=False) if ch_type == "grad": mask_ = ( mask[np.ix_(picks[::2], time_idx)] | mask[np.ix_(picks[1::2], time_idx)] ) else: # mag, eeg, planar1, planar2 mask_ = mask[np.ix_(picks, time_idx)] # set up colormap _vlim = [ _setup_vmin_vmax(data[:, i], *vlim, norm=merge_channels) for i in range(n_times) ] _vlim = (np.min(_vlim), np.max(_vlim)) cmap = _setup_cmap(cmap, n_axes=n_times, norm=_vlim[0] >= 0) # set up contours if not isinstance(contours, (list, np.ndarray)): _, contours = _set_contour_locator(*_vlim, contours) # prepare for main loop over times kwargs = dict( sensors=sensors, res=res, names=names, cmap=cmap[0], cnorm=cnorm, mask_params=mask_params, outlines=outlines, contours=contours, image_interp=image_interp, show=False, extrapolate=extrapolate, sphere=sphere, border=border, ch_type=ch_type, ) images, contours_ = [], [] # loop over times for average_idx, (time, this_average) in enumerate(zip(times, average)): tp, cn, interp = _plot_topomap( data[:, average_idx], pos, axes=axes[average_idx], mask=mask_[:, average_idx] if mask is not None else None, vmin=_vlim[0], vmax=_vlim[1], **kwargs, ) images.append(tp) if cn is not None: contours_.append(cn) if time_format != "": if this_average is None: axes_title = time_format % (time * scaling_time) else: tmin_ = averaged_times[average_idx][0] tmax_ = averaged_times[average_idx][-1] from_time = time_format % (tmin_ * scaling_time) from_time = from_time.split(" ")[0] # Remove unit to_time = time_format % (tmax_ * scaling_time) axes_title = f"{from_time} – {to_time}" del from_time, to_time, tmin_, tmax_ axes[average_idx].set_title(axes_title) if interactive: # Add a slider to the figure and start publishing and subscribing to time_change # events. 
kwargs.update(vlim=_vlim) axes.append(fig.add_subplot(gs[1])) slider = Slider( axes[-1], "Time", evoked.times[0], evoked.times[-1], valinit=times[0], valfmt="%1.2fs", ) slider.vline.remove() # remove initial point indicator func = _merge_ch_data if merge_channels else lambda x: x def _slider_changed(val): publish(fig, TimeChange(time=val)) slider.on_changed(_slider_changed) ts = np.tile(evoked.times, len(evoked.data)).reshape(evoked.data.shape) axes[-1].plot(ts, evoked.data, color="k") axes[-1].slider = slider subscribe( fig, "time_change", partial( _on_time_change, fig=fig, data=evoked.data, times=evoked.times, pos=pos, scaling=scaling, func=func, time_format=time_format, scaling_time=scaling_time, slider=slider, kwargs=kwargs, ), ) subscribe( fig, "colormap_range", partial(_on_colormap_range, kwargs=kwargs), ) if colorbar: if nrows is None or ncols is None: # axes were given by the user, so don't resize the colorbar cax = axes[-1] else: # use the default behavior cax = None cbar = fig.colorbar(images[-1], ax=axes, cax=cax, format=cbar_fmt, shrink=0.6) if unit is not None: cbar.ax.set_title(unit) if cn is not None: cbar.set_ticks(contours) cbar.ax.tick_params(labelsize=7) if cmap[1]: for im in images: im.axes.CB = DraggableColorbar( cbar, im, kind="evoked_topomap", ch_type=ch_type ) if proj == "interactive": _check_delayed_ssp(evoked) params = dict( evoked=evoked, fig=fig, projs=evoked.info["projs"], picks=picks, images=images, contours_=contours_, pos=pos, time_idx=time_idx, res=res, plot_update_proj_callback=_plot_update_evoked_topomap, merge_channels=merge_channels, scale=scaling, axes=axes[: len(axes) - bool(interactive)], contours=contours, interp=interp, extrapolate=extrapolate, ) _draw_proj_checkbox(None, params) # This is mostly for testing purposes, but it's also consistent with # raw.plot, so maybe not a bad thing in principle either from mne.viz._figure import BrowserParams fig.mne = BrowserParams(proj_checkboxes=params["proj_checks"]) plt_show(show, 
block=False) if axes_given: fig.canvas.draw() return fig def _resize_cbar(cax, n_fig_axes, size=1): """Resize colorbar.""" cpos = cax.get_position() if size <= 1: cpos.x0 = 1 - (0.7 + 0.1 / size) / n_fig_axes cpos.x1 = cpos.x0 + 0.1 / n_fig_axes cpos.y0 = 0.2 cpos.y1 = 0.7 cax.set_position(cpos) def _on_time_change( event, fig, data, times, pos, scaling, func, time_format, scaling_time, slider, kwargs, ): """Handle updating topomap to show a new time.""" idx = np.argmin(np.abs(times - event.time)) data = func(data[:, idx]).ravel() * scaling ax = fig.axes[0] ax.clear() im, _ = plot_topomap(data, pos, axes=ax, **kwargs) if hasattr(ax, "CB"): ax.CB.mappable = im _resize_cbar(ax.CB.cbar.ax, 2) if time_format is not None: ax.set_title(time_format % (event.time * scaling_time)) # Updating the slider will generate a new time_change event. To prevent an # infinite loop, only update the slider if the time has actually changed. if event.time != slider.val: slider.set_val(event.time) ax.figure.canvas.draw_idle() def _on_colormap_range(event, kwargs): """Handle updating colormap range.""" logger.debug(f"Updating colormap range to {event.fmin}, {event.fmax}") kwargs.update(vlim=(event.fmin, event.fmax), cmap=event.cmap) def _plot_topomap_multi_cbar( data, pos, ax, *, vlim, title, unit, cmap, outlines, colorbar, cbar_fmt, sphere, ch_type, sensors, names, mask, mask_params, contours, image_interp, extrapolate, border, res, size, cnorm, ): _hide_frame(ax) _vlim = ( np.min(data) if vlim[0] is None else vlim[0], np.max(data) if vlim[1] is None else vlim[1], ) # this definition of "norm" allows non-diverging colormap for cases # where min & vmax are both negative (e.g., when they are power in dB) signs = np.sign(_vlim) norm = len(set(signs)) == 1 or np.any(signs == 0) _cmap = _setup_cmap(cmap, norm=norm) if title is not None: ax.set_title(title, fontsize=10) im, _ = plot_topomap( data, pos, ch_type=ch_type, sensors=sensors, names=names, mask=mask, mask_params=mask_params, 
contours=contours, outlines=outlines, sphere=sphere, image_interp=image_interp, extrapolate=extrapolate, border=border, res=res, size=size, cmap=_cmap[0], vlim=_vlim, cnorm=cnorm, axes=ax, show=False, onselect=None, ) if colorbar: cbar, cax = _add_colorbar(ax, im, cmap, title=None, format_=cbar_fmt) cbar.set_ticks(_vlim) if unit is not None: cbar.ax.set_ylabel(unit, fontsize=8) cbar.ax.tick_params(labelsize=8) @legacy(alt="Epochs.compute_psd().plot_topomap()") @verbose def plot_epochs_psd_topomap( epochs, bands=None, tmin=None, tmax=None, proj=False, *, bandwidth=None, adaptive=False, low_bias=True, normalization="length", ch_type=None, normalize=False, agg_fun=None, dB=False, sensors=True, names=None, mask=None, mask_params=None, contours=0, outlines="head", sphere=None, image_interp=_INTERPOLATION_DEFAULT, extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, size=1, cmap=None, vlim=(None, None), cnorm=None, colorbar=True, cbar_fmt="auto", units=None, axes=None, show=True, n_jobs=None, verbose=None, ): """Plot the topomap of the power spectral density across epochs. Parameters ---------- epochs : instance of Epochs The epochs object. %(bands_psd_topo)s %(tmin_tmax_psd)s %(proj_psd)s bandwidth : float The bandwidth of the multi taper windowing function in Hz. The default value is a window half-bandwidth of 4 Hz. adaptive : bool Use adaptive weights to combine the tapered spectra into PSD (slow, use n_jobs >> 1 to speed up computation). low_bias : bool Only use tapers with more than 90%% spectral concentration within bandwidth. %(normalization)s %(ch_type_topomap_psd)s %(normalize_psd_topo)s %(agg_fun_psd_topo)s %(dB_plot_topomap)s %(sensors_topomap)s %(names_topomap)s %(mask_evoked_topomap)s %(mask_params_topomap)s %(contours_topomap)s %(outlines_topomap)s %(sphere_topomap_auto)s %(image_interp_topomap)s %(extrapolate_topomap)s .. versionchanged:: 0.21 - The default was changed to ``'local'`` for MEG sensors. 
@fill_doc
def plot_psds_topomap(
    psds,
    freqs,
    pos,
    *,
    bands=None,
    ch_type="eeg",
    normalize=False,
    agg_fun=None,
    dB=True,
    sensors=True,
    names=None,
    mask=None,
    mask_params=None,
    contours=0,
    outlines="head",
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    res=64,
    size=1,
    cmap=None,
    vlim=(None, None),
    cnorm=None,
    colorbar=True,
    cbar_fmt="auto",
    unit=None,
    axes=None,
    show=True,
):
    """Plot spatial maps of PSDs.

    Parameters
    ----------
    psds : array of float, shape (n_channels, n_freqs)
        Power spectral densities.
    freqs : array of float, shape (n_freqs,)
        Frequencies used to compute psds.
    %(pos_topomap_psd)s
    %(bands_psd_topo)s
    %(ch_type_topomap)s
    %(normalize_psd_topo)s
    %(agg_fun_psd_topo)s
    %(dB_plot_topomap)s
    %(sensors_topomap)s
    %(names_topomap)s
    %(mask_evoked_topomap)s
    %(mask_params_topomap)s
    %(contours_topomap)s
    %(outlines_topomap)s
    %(sphere_topomap_auto)s
    %(image_interp_topomap)s
    %(extrapolate_topomap)s

        .. versionchanged:: 0.21

           - The default was changed to ``'local'`` for MEG sensors.
           - ``'local'`` was changed to use a convex hull mask
           - ``'head'`` was changed to extrapolate out to the clipping circle.
    %(border_topomap)s

        .. versionadded:: 0.20
    %(res_topomap)s
    %(size_topomap)s
    %(cmap_topomap)s
    %(vlim_plot_topomap_psd)s

        .. versionadded:: 0.21
    %(cnorm)s

        .. versionadded:: 1.2
    %(colorbar_topomap)s
    %(cbar_fmt_topomap_psd)s
    unit : str | None
        Measurement unit to be displayed with the colorbar. If ``None``, no
        unit is displayed (only "power" or "dB" as appropriate).
    %(axes_spectrum_plot_topomap)s
    %(show)s

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        Figure with a topomap subplot for each band.

    Notes
    -----
    Neither ``psds`` nor a user-supplied ``bands`` dict is modified; the
    function works on copies.
    """
    import matplotlib.pyplot as plt
    from matplotlib.axes import Axes

    # handle some defaults
    sphere = _check_sphere(sphere)
    if cbar_fmt == "auto":
        cbar_fmt = "%0.1f" if dB else "%0.3f"
    # make sure `bands` is a dict
    if bands is None:
        bands = {
            "Delta (0-4 Hz)": (0, 4),
            "Theta (4-8 Hz)": (4, 8),
            "Alpha (8-12 Hz)": (8, 12),
            "Beta (12-30 Hz)": (12, 30),
            "Gamma (30-45 Hz)": (30, 45),
        }
    elif not hasattr(bands, "keys"):
        # convert legacy list-of-tuple input to a dict
        bands = {band[-1]: band[:-1] for band in bands}
        logger.info(
            "converting legacy list-of-tuples input to a dict for the "
            "`bands` parameter"
        )
    else:
        # work on a shallow copy: single-frequency entries are rewritten below
        # and the caller's dict must not change as a side effect
        bands = dict(bands)
    # upconvert single freqs to band upper/lower edges as needed
    bin_spacing = np.diff(freqs)[0]
    bin_edges = np.array([0, bin_spacing]) - bin_spacing / 2
    for band, _edges in bands.items():
        if not hasattr(_edges, "__len__"):
            _edges = (_edges,)
        if len(_edges) == 1:
            # only values are replaced (no keys added/removed), so updating
            # the dict while iterating over it is safe
            nearest_freq = freqs[np.argmin(np.abs(freqs - _edges[0]))]
            bands[band] = tuple(bin_edges + nearest_freq)
    # normalize data (if requested); out-of-place so the input stays intact
    if normalize:
        psds = psds / psds.sum(axis=-1, keepdims=True)
        assert np.allclose(psds.sum(axis=-1), 1.0)
    # aggregate within bands
    if agg_fun is None:
        agg_fun = np.sum if normalize else np.mean
    freq_masks = list()
    for band, (fmin, fmax) in bands.items():
        _mask = (fmin < freqs) & (freqs < fmax)
        # make sure no bands are empty
        if _mask.sum() == 0:
            raise RuntimeError(f'No frequencies in band "{band}" ({fmin}, {fmax})')
        freq_masks.append(_mask)
    band_data = [agg_fun(psds[:, _mask], axis=1) for _mask in freq_masks]
    if dB and not normalize:
        band_data = [10 * np.log10(_d) for _d in band_data]
    # handle vmin/vmax; the isinstance guard avoids an elementwise comparison
    # if an array-like is passed for vlim
    joint_vlim = isinstance(vlim, str) and vlim == "joint"
    if joint_vlim:
        vlim = (np.array(band_data).min(), np.array(band_data).max())
    # unit label
    if unit is None:
        unit = "dB" if dB and not normalize else "power"
    else:
        _dB = dB and not normalize
        unit = _format_units_psd(unit, dB=_dB)
    # set up figure / axes
    n_axes = len(bands)
    user_passed_axes = axes is not None
    if user_passed_axes:
        if isinstance(axes, Axes):
            axes = [axes]
        _validate_if_list_of_axes(axes, n_axes)
        fig = axes[0].figure
    else:
        fig, axes = plt.subplots(
            1, n_axes, figsize=(2 * n_axes, 1.5), layout="constrained"
        )
        if n_axes == 1:
            axes = [axes]
    # loop over subplots/frequency bands
    for ax, _data, title in zip(axes, band_data, bands):
        # Honor the caller's ``colorbar`` choice. With joint limits a single
        # colorbar (on the last axes) is enough because all maps share it.
        show_cbar = bool(colorbar) and ((not joint_vlim) or ax is axes[-1])
        _plot_topomap_multi_cbar(
            _data,
            pos,
            ax,
            title=title,
            vlim=vlim,
            cmap=cmap,
            outlines=outlines,
            colorbar=show_cbar,
            unit=unit,
            cbar_fmt=cbar_fmt,
            sphere=sphere,
            ch_type=ch_type,
            sensors=sensors,
            names=names,
            mask=mask,
            mask_params=mask_params,
            contours=contours,
            image_interp=image_interp,
            extrapolate=extrapolate,
            border=border,
            res=res,
            size=size,
            cnorm=cnorm,
        )

    if not user_passed_axes:
        fig.canvas.draw()
        plt_show(show)
    return fig
def _onselect(
    eclick,
    erelease,
    tfr,
    pos,
    ch_type,
    itmin,
    itmax,
    ifmin,
    ifmax,
    cmap,
    fig,
    layout=None,
):
    """Handle drawing average tfr over channels called from topomap.

    Parameters
    ----------
    eclick, erelease : matplotlib mouse events
        Press and release events delimiting the selection rectangle.
    tfr : time-frequency object
        Provides ``data``, ``info``, ``ch_names``, ``times`` and ``freqs``.
    pos : array, shape (n_sensors, 2)
        Sensor positions in the coordinates of the clicked axes.
    ch_type : str
        One of ``'mag'``, ``'grad'`` or ``'eeg'``.
    itmin, itmax, ifmin, ifmax : int | None
        Time/frequency index limits of the window that is averaged.
    cmap : colormap
        Colormap for the averaged image.
    fig : list
        Mutable single-element holder for the output figure; it is filled or
        refreshed here so that repeated selections reuse one window.
    layout : Layout | None
        Layout forwarded to the gradiometer pairing helper.

    Raises
    ------
    ValueError
        If ``ch_type`` is not one of the supported types.
    """
    import matplotlib.pyplot as plt
    from matplotlib.collections import PathCollection

    from ..channels.layout import _pair_grad_sensors

    ax = eclick.inaxes
    xmin = min(eclick.xdata, erelease.xdata)
    xmax = max(eclick.xdata, erelease.xdata)
    ymin = min(eclick.ydata, erelease.ydata)
    ymax = max(eclick.ydata, erelease.ydata)
    indices = (
        (pos[:, 0] < xmax)
        & (pos[:, 0] > xmin)
        & (pos[:, 1] < ymax)
        & (pos[:, 1] > ymin)
    )
    colors = ["r" if ii else "k" for ii in indices]
    indices = np.where(indices)[0]
    for collection in ax.collections:
        if isinstance(collection, PathCollection):  # this is our "scatter"
            collection.set_color(colors)
    ax.figure.canvas.draw()
    if len(indices) == 0:
        return
    data = tfr.data
    if ch_type == "mag":
        picks = pick_types(tfr.info, meg=ch_type, ref_meg=False)
        data = np.mean(data[indices, ifmin:ifmax, itmin:itmax], axis=0)
        chs = [tfr.ch_names[picks[x]] for x in indices]
    elif ch_type == "grad":
        grads = _pair_grad_sensors(tfr.info, layout=layout, topomap_coords=False)
        idxs = list()
        for idx in indices:
            idxs.append(grads[idx * 2])
            idxs.append(grads[idx * 2 + 1])  # pair of grads
        data = np.mean(data[idxs, ifmin:ifmax, itmin:itmax], axis=0)
        chs = [tfr.ch_names[x] for x in idxs]
    elif ch_type == "eeg":
        picks = pick_types(tfr.info, meg=False, eeg=True, ref_meg=False)
        data = np.mean(data[indices, ifmin:ifmax, itmin:itmax], axis=0)
        chs = [tfr.ch_names[picks[x]] for x in indices]
    else:
        # Previously this fell through and failed later on the unbound
        # ``chs`` name; fail early with an explicit message instead.
        raise ValueError(
            f"Unsupported channel type {ch_type!r}; expected 'mag', 'grad' or 'eeg'."
        )
    logger.info("Averaging TFR over channels " + str(chs))
    if len(fig) == 0:
        fig.append(figure_nobar())
    if not plt.fignum_exists(fig[0].number):
        fig[0] = figure_nobar()
    ax = fig[0].add_subplot(111)
    itmax = len(tfr.times) - 1 if itmax is None else min(itmax, len(tfr.times) - 1)
    ifmax = len(tfr.freqs) - 1 if ifmax is None else min(ifmax, len(tfr.freqs) - 1)
    if itmin is None:
        itmin = 0
    if ifmin is None:
        ifmin = 0
    extent = (
        tfr.times[itmin] * 1e3,
        tfr.times[itmax] * 1e3,
        tfr.freqs[ifmin],
        tfr.freqs[ifmax],
    )

    title = f"Average over {len(chs)} {ch_type} channels."
    ax.set_title(title)
    ax.set_xlabel("Time (ms)")
    ax.set_ylabel("Frequency (Hz)")
    img = ax.imshow(data, extent=extent, aspect="auto", origin="lower", cmap=cmap)
    if len(fig[0].get_axes()) < 2:
        # The right-hand side runs first, so the colorbar axes exists (as
        # axes #1) by the time the attribute is stored on it.
        fig[0].get_axes()[1].cbar = fig[0].colorbar(mappable=img)
    else:
        # ``Colorbar.on_mappable_changed`` is no longer available in current
        # Matplotlib; ``update_normal`` is the supported way to re-sync an
        # existing colorbar with a new mappable.
        fig[0].get_axes()[1].cbar.update_normal(img)
    fig[0].canvas.draw()
    plt.figure(fig[0].number)
    plt_show(True)
def _hide_frame(ax):
    """Strip ticks and the surrounding frame from a topomap axes."""
    ax.get_yticks()
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_ticks([])
    ax.set_frame_on(False)


def _check_extrapolate(extrapolate, ch_type):
    """Validate ``extrapolate`` and resolve ``'auto'`` to a concrete mode."""
    _check_option("extrapolate", extrapolate, ("box", "local", "head", "auto"))
    if extrapolate != "auto":
        return extrapolate
    # 'auto' depends on the sensor family: split MEG types use 'local'
    if ch_type in _MEG_CH_TYPES_SPLIT:
        return "local"
    return "head"


@verbose
def _init_anim(
    ax,
    ax_line,
    ax_cbar,
    params,
    merge_channels,
    sphere,
    ch_type,
    image_interp,
    extrapolate,
    verbose,
):
    """Draw the first frame of the animated topomap and cache shared state.

    All interpolated frames are precomputed and stored in ``params["Zis"]``;
    grid, limits, colormap, clip patch and outlines are stored in ``params``
    as well so the per-frame update can reuse them. Returns the tuple of
    artists created here (line marker, image, text and contour artists).
    """
    logger.info("Initializing animation...")
    data = params["data"]
    artists = []
    vmin = params.get("vmin")
    vmax = params.get("vmax")

    if params["butterfly"]:
        all_times = params["all_times"]
        for trace in data:
            ax_line.plot(all_times, trace, color="k", lw=1)
        vmin, vmax = _setup_vmin_vmax(data, vmin, vmax)
        tick_values = np.around(np.linspace(vmin, vmax, 5), -1)
        ax_line.set(yticks=tick_values, xlim=all_times[[0, -1]])
        params["line"] = ax_line.axvline(all_times[0], color="r")
        artists.append(params["line"])

    if merge_channels:
        from mne.channels.layout import _merge_ch_data

        data, _ = _merge_ch_data(data, "grad", [])

    # strictly positive data gets a sequential colormap, otherwise diverging
    norm = bool(np.min(data) > 0)
    if norm:
        cmap = "Reds"
    else:
        cmap = "RdBu_r"
    vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm)

    outlines = _make_head_outlines(sphere, params["pos"], "head", params["clip_origin"])
    _hide_frame(ax)
    extent, Xi, Yi, interp = _setup_interp(
        pos=params["pos"],
        res=64,
        image_interp=image_interp,
        extrapolate=extrapolate,
        outlines=outlines,
        border=0,
    )
    patch_ = _get_patch(outlines, extrapolate, interp, ax)

    # interpolate every requested frame up front
    params["Zis"] = [
        interp.set_values(data[:, frame])(Xi, Yi) for frame in params["frames"]
    ]
    Zi = params["Zis"][0]
    zi_min = np.nanmin(params["Zis"])
    zi_max = np.nanmax(params["Zis"])
    cont_lims = np.linspace(zi_min, zi_max, 7, endpoint=False)[1:]
    params.update(
        vmin=vmin,
        vmax=vmax,
        Xi=Xi,
        Yi=Yi,
        Zi=Zi,
        extent=extent,
        cmap=cmap,
        cont_lims=cont_lims,
    )

    # plot map and contour
    im = ax.imshow(
        Zi,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        origin="lower",
        aspect="equal",
        extent=extent,
        interpolation="bilinear",
    )
    ax.autoscale(enable=True, tight=True)
    ax.figure.colorbar(im, cax=ax_cbar)
    cont = ax.contour(Xi, Yi, Zi, levels=cont_lims, colors="k", linewidths=1)

    im.set_clip_path(patch_)
    text = ax.text(0.55, 0.95, "", transform=ax.transAxes, va="center", ha="right")
    params["text"] = text
    artists.extend([im, text])

    cont_collections = _cont_collections(cont)
    for col in cont_collections:
        col.set_clip_path(patch_)

    outlines_ = _draw_outlines(ax, outlines)
    params.update({"patch": patch_, "outlines": outlines_})
    return tuple(artists) + cont_collections
text = ax.text(0.45, 1.15, "", transform=ax.transAxes) for k, (x, y) in params["outlines"].items(): if "mask" in k: continue ax.plot(x, y, color="k", linewidth=1, clip_on=False) _hide_frame(ax) text.set_text(title) vmin = params["vmin"] vmax = params["vmax"] Xi = params["Xi"] Yi = params["Yi"] Zi = params["Zis"][frame] extent = params["extent"] cmap = params["cmap"] patch = params["patch"] im = ax.imshow( Zi, cmap=cmap, vmin=vmin, vmax=vmax, origin="lower", aspect="equal", extent=extent, interpolation="bilinear", ) cont_lims = params["cont_lims"] with warnings.catch_warnings(record=True): warnings.simplefilter("ignore") cont = ax.contour(Xi, Yi, Zi, levels=cont_lims, colors="k", linewidths=1) im.set_clip_path(patch) cont_collections = _cont_collections(cont) for col in cont_collections: col.set_clip_path(patch) items = [im, text] if params["butterfly"]: all_times = params["all_times"] line = params["line"] line.remove() ylim = ax_line.get_ylim() params["line"] = ax_line.axvline(all_times[time_idx], color="r") ax_line.set_ylim(ylim) items.append(params["line"]) params["frame"] = frame return tuple(items) + cont_collections def _pause_anim(event, params): """Pause or continue the animation on mouse click.""" params["pause"] = not params["pause"] def _key_press(event, params): """Handle key presses for the animation.""" if event.key == "left": params["pause"] = True params["frame"] = max(params["frame"] - 1, 0) elif event.key == "right": params["pause"] = True params["frame"] = min(params["frame"] + 1, len(params["frames"]) - 1) def _topomap_animation( evoked, ch_type, times, frame_rate, butterfly, blit, show, time_unit, sphere, image_interp, extrapolate, *, vmin, vmax, verbose=None, ): """Make animation of evoked data as topomap timeseries. See mne.evoked.Evoked.animate_topomap. 
""" from matplotlib import animation from matplotlib import pyplot as plt if ch_type is None: ch_type = _get_plot_ch_type(evoked, ch_type) time_unit, _ = _check_time_unit(time_unit, evoked.times) if times is None: times = np.linspace(evoked.times[0], evoked.times[-1], 10) times = np.array(times) if times.ndim != 1: raise ValueError(f"times must be 1D, got {times.ndim} dimensions") if max(times) > evoked.times[-1] or min(times) < evoked.times[0]: raise ValueError("All times must be inside the evoked time series.") frames = [np.abs(evoked.times - time).argmin() for time in times] picks, pos, merge_channels, _, ch_type, sphere, clip_origin = _prepare_topomap_plot( evoked, ch_type, sphere=sphere ) data = evoked.data[picks, :] data *= _handle_default("scalings")[ch_type] norm = np.min(data) >= 0 vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm) fig = plt.figure(figsize=(6, 5), layout="constrained") shape = (8, 12) colspan = shape[1] - 1 rowspan = shape[0] - bool(butterfly) ax = plt.subplot2grid(shape, (0, 0), rowspan=rowspan, colspan=colspan) if butterfly: ax_line = plt.subplot2grid(shape, (rowspan, 0), colspan=colspan) else: ax_line = None if isinstance(frames, Integral): frames = np.linspace(0, len(evoked.times) - 1, frames).astype(int) ax_cbar = plt.subplot2grid(shape, (0, colspan), rowspan=rowspan) ax_cbar.set_title(_handle_default("units")[ch_type], fontsize=10) extrapolate = _check_extrapolate(extrapolate, ch_type) params = dict( data=data, pos=pos, all_times=evoked.times, frame=0, frames=frames, butterfly=butterfly, blit=blit, pause=False, times=times, time_unit=time_unit, clip_origin=clip_origin, vmin=vmin, vmax=vmax, ) init_func = partial( _init_anim, ax=ax, ax_cbar=ax_cbar, ax_line=ax_line, params=params, merge_channels=merge_channels, sphere=sphere, ch_type=ch_type, image_interp=image_interp, extrapolate=extrapolate, verbose=verbose, ) animate_func = partial(_animate, ax=ax, ax_line=ax_line, params=params) pause_func = partial(_pause_anim, params=params) 
fig.canvas.mpl_connect("button_press_event", pause_func) key_press_func = partial(_key_press, params=params) fig.canvas.mpl_connect("key_press_event", key_press_func) if frame_rate is None: frame_rate = evoked.info["sfreq"] / 10.0 interval = 1000 / frame_rate # interval is in ms anim = animation.FuncAnimation( fig, animate_func, init_func=init_func, frames=len(frames), interval=interval, blit=blit, ) fig.mne_animation = anim # to make sure anim is not garbage collected plt_show(show, block=False) if "line" in params: # Finally remove the vertical line so it does not appear in saved fig. params["line"].remove() return fig, anim def _set_contour_locator(vmin, vmax, contours): """Set correct contour levels.""" locator = None if isinstance(contours, Integral) and contours > 0: from matplotlib import ticker # nbins = ticks - 1, since 2 of the ticks are vmin and vmax, the # correct number of bins is equal to contours + 1. locator = ticker.MaxNLocator(nbins=contours + 1) contours = locator.tick_values(vmin, vmax) return locator, contours def _plot_corrmap( data, subjs, indices, ch_type, ica, label, *, show, outlines, cmap, contours, sensors=False, template=False, sphere=None, image_interp=_INTERPOLATION_DEFAULT, extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, show_names=False, ): """Customize ica.plot_components for corrmap.""" from ..channels.layout import _merge_ch_data if not template: title = "Detected components" if label is not None: title += " of type " + label else: title = "Supplied template" picks = list(range(len(data))) p = 20 if len(picks) > p: # plot components by sets of 20 n_components = len(picks) figs = [ _plot_corrmap( data[k : k + p], subjs[k : k + p], indices[k : k + p], ch_type, ica, label, show=show, outlines=outlines, cmap=cmap, contours=contours, sensors=sensors, image_interp=image_interp, extrapolate=extrapolate, border=border, show_names=show_names, ) for k in range(0, n_components, p) ] return figs elif np.isscalar(picks): picks = 
def _trigradient(x, y, z):
    """Take gradients of z on a mesh.

    A triangulation of the ``(x, y)`` points is built and the gradient of a
    cubic interpolant of ``z`` is evaluated at the triangulation nodes.
    Returns the pair ``(dx, dy)``.
    """
    from matplotlib.tri import CubicTriInterpolator, Triangulation

    with warnings.catch_warnings():  # catch matplotlib warnings
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        tri = Triangulation(x, y)
        tci = CubicTriInterpolator(tri, z)
        dx, dy = tci.gradient(tri.x, tri.y)
    return dx, dy


@fill_doc
def plot_arrowmap(
    data,
    info_from,
    info_to=None,
    scale=3e-10,
    vlim=(None, None),
    cnorm=None,
    cmap=None,
    sensors=True,
    res=64,
    axes=None,
    show_names=False,
    mask=None,
    mask_params=None,
    outlines="head",
    contours=6,
    image_interp=_INTERPOLATION_DEFAULT,
    show=True,
    onselect=None,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    sphere=None,
):
    """Plot arrow map.

    Compute arrowmaps, based upon the Hosaka-Cohen transformation
    :footcite:`CohenHosaka1976`, these arrows represent an estimation of the
    current flow underneath the MEG sensors. They are a poor man's MNE.

    Since planar gradiometers take gradients along latitude and longitude,
    they need to be projected to the flattened manifold spanned by
    magnetometer or radial gradiometers before taking the gradients in the 2D
    Cartesian coordinate system for visualization on the 2D topoplot. You can
    use the ``info_from`` and ``info_to`` parameters to interpolate from
    gradiometer data to magnetometer data.

    Parameters
    ----------
    data : array, shape (n_channels,)
        The data values to plot.
    info_from : instance of Info
        The measurement info from data to interpolate from.
    info_to : instance of Info | None
        The measurement info to interpolate to. If None, it is assumed
        to be the same as info_from.
    scale : float, default 3e-10
        To scale the arrows.
    %(vlim_plot_topomap)s

        .. versionadded:: 1.2
    %(cnorm)s

        .. versionadded:: 1.2
    %(cmap_topomap_simple)s
    %(sensors_topomap)s
    %(res_topomap)s
    %(axes_plot_topomap)s
    %(show_names_topomap)s
        If ``True``, a list of names must be provided (see ``names`` keyword).
    %(mask_topomap)s
    %(mask_params_topomap)s
    %(outlines_topomap)s
    %(contours_topomap)s
    %(image_interp_topomap)s
    %(show)s
    onselect : callable | None
        Handle for a function that is called when the user selects a set of
        channels by rectangle selection (matplotlib ``RectangleSelector``). If
        None interactive selection is disabled. Defaults to None.
    %(extrapolate_topomap)s

        .. versionadded:: 0.18

        .. versionchanged:: 0.21

           - The default was changed to ``'local'`` for MEG sensors.
           - ``'local'`` was changed to use a convex hull mask
           - ``'head'`` was changed to extrapolate out to the clipping circle.
    %(sphere_topomap_auto)s

    Returns
    -------
    fig : matplotlib.figure.Figure
        The Figure of the plot.

    Notes
    -----
    .. versionadded:: 0.17

    References
    ----------
    .. footbibliography::
    """
    from matplotlib import pyplot as plt

    from ..forward import _map_meg_or_eeg_channels

    sphere = _check_sphere(sphere, info_from)
    ch_type = _picks_by_type(info_from)

    if len(ch_type) > 1:
        raise ValueError(
            "Multiple channel types are not supported. "
            "All channels must either be of type 'grad' "
            "or 'mag'."
        )
    else:
        ch_type = ch_type[0][0]

    if ch_type not in ("mag", "grad"):
        raise ValueError(
            f"Channel type '{ch_type}' not supported. Supported channel "
            "types are 'mag' and 'grad'."
        )

    if info_to is None:
        if ch_type != "mag":
            # Gradiometer data must be mapped onto magnetometers, so a target
            # info is mandatory; without this check the code below would try
            # to inspect ``None`` and fail with an unrelated error.
            raise ValueError(
                "info_to must be provided (with 'mag' channels) when info_from "
                f"contains '{ch_type}' channels."
            )
        info_to = info_from
    else:
        ch_type = _picks_by_type(info_to)
        if len(ch_type) > 1:
            raise ValueError("Multiple channel types are not supported.")
        else:
            ch_type = ch_type[0][0]

        if ch_type != "mag":
            raise ValueError(f"only 'mag' channel type is supported. Got {ch_type}")

    if info_to is not info_from:
        info_to = pick_info(info_to, pick_types(info_to, meg=True))
        info_from = pick_info(info_from, pick_types(info_from, meg=True))
        # XXX should probably support the "origin" argument
        mapping = _map_meg_or_eeg_channels(
            info_from, info_to, origin=(0.0, 0.0, 0.04), mode="accurate"
        )
        data = np.dot(mapping, data)

    _, pos, _, _, _, sphere, clip_origin = _prepare_topomap_plot(
        info_to, "mag", sphere=sphere
    )
    outlines = _make_head_outlines(sphere, pos, outlines, clip_origin)
    if axes is None:
        fig, axes = plt.subplots(layout="constrained")
    else:
        fig = axes.figure
    plot_topomap(
        data,
        pos,
        axes=axes,
        vlim=vlim,
        cmap=cmap,
        cnorm=cnorm,
        sensors=sensors,
        res=res,
        mask=mask,
        mask_params=mask_params,
        outlines=outlines,
        contours=contours,
        image_interp=image_interp,
        show=False,
        onselect=onselect,
        extrapolate=extrapolate,
        sphere=sphere,
        ch_type=ch_type,
    )
    x, y = tuple(pos.T)
    dx, dy = _trigradient(x, y, data)
    # arrow components: (dy, -dx), i.e. the gradient rotated by 90 degrees
    dxx = dy.data
    dyy = -dx.data
    axes.quiver(x, y, dxx, dyy, scale=scale, color="k", lw=1)
    plt_show(show)

    return fig
ed_matrix : array of float, shape (n_channels, n_channels) The electrical distance matrix for each pair of EEG electrodes. Can be generated via :func:`mne.preprocessing.compute_bridged_electrodes`. title : str A title to add to the plot. topomap_args : dict | None Arguments to pass to :func:`mne.viz.plot_topomap`. Returns ------- fig : instance of matplotlib.figure.Figure The topoplot figure handle. See Also -------- mne.preprocessing.compute_bridged_electrodes """ import matplotlib.pyplot as plt from ..channels.layout import _find_topomap_coords if topomap_args is None: topomap_args = dict() else: topomap_args = topomap_args.copy() # don't change original picks = pick_types(info, eeg=True) topomap_args.setdefault("image_interp", "nearest") topomap_args.setdefault("cmap", "summer_r") topomap_args.setdefault("names", pick_info(info, picks).ch_names) topomap_args.setdefault("contours", False) sphere = topomap_args.get("sphere", _check_sphere(None)) if "axes" not in topomap_args: fig, ax = plt.subplots(layout="constrained") topomap_args["axes"] = ax else: fig = None # handle colorbar here instead of in plot_topomap colorbar = topomap_args.pop("colorbar", True) if ed_matrix.shape[1:] != (picks.size, picks.size): raise RuntimeError( f"Expected {(ed_matrix.shape[0], picks.size, picks.size)} " f"shaped `ed_matrix`, got {ed_matrix.shape}" ) # fill in lower triangular ed_matrix = ed_matrix.copy() tril_idx = np.tril_indices(picks.size) for epo_idx in range(ed_matrix.shape[0]): ed_matrix[epo_idx][tril_idx] = ed_matrix[epo_idx].T[tril_idx] elec_dists = np.median(np.nanmin(ed_matrix, axis=1), axis=0) im, cn = plot_topomap(elec_dists, pick_info(info, picks), **topomap_args) fig = im.figure if fig is None else fig # add bridged connections for idx0, idx1 in bridged_idx: pos = _find_topomap_coords(info, [idx0, idx1], sphere=sphere) im.axes.plot([pos[0, 0], pos[1, 0]], [pos[0, 1], pos[1, 1]], color="r") if title is not None: im.axes.set_title(title) if colorbar: cax = 
def plot_ch_adjacency(info, adjacency, ch_names, kind="2d", edit=False):
    """Plot channel adjacency.

    Parameters
    ----------
    info : instance of Info
        Info object with channel locations.
    adjacency : array
        Array of channels x channels shape. Defines which channels are adjacent
        to each other. Note that if you edit adjacencies
        (via ``edit=True``), this array will be modified in place.
    ch_names : list of str
        Names of successive channels in the ``adjacency`` matrix.
    kind : str
        How to plot the adjacency. Can be either ``'3d'`` or ``'2d'``.
    edit : bool
        Whether to allow interactive editing of the adjacency matrix via
        clicking respective channel pairs. Once clicked, the channel is
        "activated" and turns green. Clicking on another channel adds or
        removes adjacency relation between the activated and newly clicked
        channel (depending on whether the channels are already adjacent or
        not); the newly clicked channel now becomes activated. Clicking on
        an activated channel deactivates it. Editing is currently only
        supported for ``kind='2d'``.

    Returns
    -------
    fig : Figure
        The :class:`~matplotlib.figure.Figure` instance where the channel
        adjacency is plotted.

    See Also
    --------
    mne.channels.get_builtin_ch_adjacencies
    mne.channels.read_ch_adjacency
    mne.channels.find_ch_adjacency

    Notes
    -----
    .. versionadded:: 1.1
    """
    import matplotlib as mpl
    import matplotlib.pyplot as plt

    _validate_type(info, Info, "info")
    _validate_type(adjacency, (np.ndarray, csr_array), "adjacency")
    # reject unknown kinds up front; otherwise ``fig`` is never created below
    _check_option("kind", kind, ("2d", "3d"))

    if edit and kind == "3d":
        raise ValueError("Editing a 3d adjacency plot is not supported.")

    # select relevant channels
    sel = pick_channels(info.ch_names, ch_names, ordered=True)
    info = pick_info(info, sel)

    # make sure adjacency is correct size wrt to inst:
    n_channels = len(info.ch_names)
    if adjacency.shape[0] != n_channels:
        raise ValueError(
            "``adjacency`` must have the same number of rows "
            "as the number of channels in ``info``. Found "
            f"{adjacency.shape[0]} channels for ``adjacency`` and"
            f" {n_channels} for ``inst``."
        )

    if kind == "3d":
        with plt.rc_context({"toolbar": "None"}):
            fig = plot_sensors(info, kind=kind, show=False)
        _set_3d_axes_equal(fig.axes[0])
    else:  # kind == "2d"
        with plt.rc_context({"toolbar": "None"}):
            fig = plot_sensors(info, kind="topomap", show=False)
        fig.axes[0].axis("equal")

    path_collection = fig.axes[0].findobj(mpl.collections.PathCollection)
    path_collection[0].set_linewidths(0.0)

    if kind == "2d":
        path_collection[0].set_alpha(0.7)
        pos = path_collection[0].get_offsets()

        # make sure nodes are on top
        path_collection[0].set_zorder(10)

        # scale node size with number of connections
        n_connections = [np.sum(adjacency[[i]]) - 1 for i in range(adjacency.shape[0])]
        node_size = [max(x, 3) ** 2.5 for x in n_connections]
        path_collection[0].set_sizes(node_size)
    else:
        # plotting channel positions via mne.viz.plot_sensors(info) and using
        # the coordinates from info['chs'][ch_idx]['loc][:3] gives different
        # positions. Also .get_offsets gives 2d projections even for 3d points
        # so we use the private _offsets3d property...
        pos = path_collection[0]._offsets3d
        pos = np.stack([pos[0].data, pos[1].data, pos[2]], axis=1)

    ax = fig.axes[0]
    lines = dict()
    n_channels = adjacency.shape[0]
    for ch_idx in range(n_channels):
        # make sure we don't repeat channels
        row = adjacency[[ch_idx], ch_idx + 1 :]
        # ``row`` is 2D (one row) for dense and sparse input alike, so the
        # neighbour offsets are the *column* indices of the nonzero entries.
        # (Taking element [0] of the result would yield row indices, i.e. all
        # zeros, and connect every channel only to its successor.)
        ch_neighbours = row.nonzero()[1]
        if len(ch_neighbours) == 0:
            continue

        # shift back from slice-relative to absolute channel indices
        ch_neighbours = ch_neighbours + (ch_idx + 1)

        for ngb_idx in ch_neighbours:
            this_pos = pos[[ch_idx, ngb_idx], :]
            ch_pair = (ch_idx, ngb_idx)
            lines[ch_pair] = ax.plot(*this_pos.T, color=(0.55, 0.55, 0.55), lw=0.75)[0]

    if edit:
        # allow interactivity in 2d plots
        highlighted = dict()
        this_onpick = partial(
            _onpick_ch_adjacency,
            axes=ax,
            positions=pos,
            highlighted=highlighted,
            line_dict=lines,
            adjacency=adjacency,
            node_size=node_size,
            path_collection=path_collection,
        )
        fig.canvas.mpl_connect("pick_event", this_onpick)

    return fig
else: # add line n_conn_change = +1 selected_pos = positions[both_nodes, :] line = axes.plot(*selected_pos.T, color="tab:green")[0] # add line to line_dict line_dict[both_nodes] = line # modify adjacency matrix _set_adjacency(adjacency, both_nodes, True) # de-highlight previous highlighted[key].remove() del highlighted[key] # update node sizes n_connections = [ np.sum(adjacency[[idx]]) - 1 + n_conn_change for idx in both_nodes ] for idx, n_conn in zip(both_nodes, n_connections): node_size[idx] = max(n_conn, 3) ** 2.5 path_collection[0].set_sizes(node_size) # highlight new node size = max(node_size[node_ind] * 2, 100) dots = axes.scatter( *positions[node_ind, :].T, color="tab:green", s=size, zorder=15 ) highlighted[node_ind] = dots axes.figure.canvas.draw() def _set_adjacency(adjacency, both_nodes, value): """Set adjacency for given node pair, caching errors for sparse arrays.""" import warnings with warnings.catch_warnings(record=True): adjacency[both_nodes, both_nodes[::-1]] = value @fill_doc def plot_regression_weights( model, *, ch_type=None, sensors=True, show_names=False, mask=None, mask_params=None, contours=6, outlines="head", sphere=None, image_interp=_INTERPOLATION_DEFAULT, extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, size=1, cmap=None, vlim=(None, None), cnorm=None, axes=None, colorbar=True, cbar_fmt="%1.1e", title=None, show=True, ): """Plot the regression weights of a fitted EOGRegression model. Parameters ---------- model : EOGRegression The fitted EOGRegression model whose weights will be plotted. %(ch_type_topomap)s %(sensors_topomap)s %(show_names_topomap)s %(mask_topomap)s %(mask_params_topomap)s %(contours_topomap)s %(outlines_topomap)s %(sphere_topomap_auto)s %(image_interp_topomap)s %(extrapolate_topomap)s .. versionchanged:: 0.21 - The default was changed to ``'local'`` for MEG sensors. - ``'local'`` was changed to use a convex hull mask - ``'head'`` was changed to extrapolate out to the clipping circle. 
@fill_doc
def plot_regression_weights(
    model,
    *,
    ch_type=None,
    sensors=True,
    show_names=False,
    mask=None,
    mask_params=None,
    contours=6,
    outlines="head",
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    res=64,
    size=1,
    cmap=None,
    vlim=(None, None),
    cnorm=None,
    axes=None,
    colorbar=True,
    cbar_fmt="%1.1e",
    title=None,
    show=True,
):
    """Plot the regression weights of a fitted EOGRegression model.

    Parameters
    ----------
    model : EOGRegression
        The fitted EOGRegression model whose weights will be plotted.
    %(ch_type_topomap)s
    %(sensors_topomap)s
    %(show_names_topomap)s
    %(mask_topomap)s
    %(mask_params_topomap)s
    %(contours_topomap)s
    %(outlines_topomap)s
    %(sphere_topomap_auto)s
    %(image_interp_topomap)s
    %(extrapolate_topomap)s

        .. versionchanged:: 0.21

           - The default was changed to ``'local'`` for MEG sensors.
           - ``'local'`` was changed to use a convex hull mask
           - ``'head'`` was changed to extrapolate out to the clipping circle.
    %(border_topomap)s

        .. versionadded:: 0.20
    %(res_topomap)s
    %(size_topomap)s
    %(cmap_topomap)s
    %(vlim_plot_topomap)s
    %(cnorm)s
    %(axes_evoked_plot_topomap)s
    %(colorbar_topomap)s
    %(cbar_fmt_topomap)s
    %(title_none)s
    %(show)s

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        Figure with one topomap per combination of channel type and
        artifact channel.

    Notes
    -----
    .. versionadded:: 1.2
    """
    import matplotlib
    import matplotlib.pyplot as plt

    from ..channels.layout import _merge_ch_data

    sphere = _check_sphere(sphere)
    if ch_type is None:
        ch_types = model.info_.get_channel_types(unique=True, only_data_chs=True)
    else:
        ch_types = [ch_type]
    del ch_type

    # one row per artifact channel, one column per channel type
    nrows = model.coef_.shape[1]
    ncols = len(ch_types)

    axes_was_none = axes is None
    if axes_was_none:
        fig, axes = plt.subplots(
            nrows,
            ncols,
            squeeze=False,
            figsize=(ncols * 2, nrows * 1.5 + 1),
            layout="constrained",
        )
        # column-major order matches the loops below (channel type outermost)
        axes = axes.T.ravel()
    else:
        if isinstance(axes, matplotlib.axes.Axes):
            axes = [axes]
        fig = axes[0].get_figure()
    if len(axes) != nrows * ncols:
        raise ValueError(
            f"axes must be a list of {nrows * ncols} axes, got "
            f"length {len(axes)} ({axes})."
        )
    axes = iter(axes)

    data_picks = _picks_to_idx(model.info_, model.picks, exclude=model.exclude)
    data_info = pick_info(model.info_, data_picks)
    artifact_ch_names = [
        model.info_["chs"][idx]["ch_name"]
        for idx in _picks_to_idx(model.info_, model.picks_artifact)
    ]

    for ch_type in ch_types:
        # Keep the per-channel-type results in their own variables so that
        # the caller's ``sphere`` and ``outlines`` are evaluated afresh for
        # every channel type instead of inheriting the previous type's values.
        (
            data_picks,
            pos,
            merge_channels,
            type_names,
            ch_type,
            this_sphere,
            clip_origin,
        ) = _prepare_topomap_plot(data_info, ch_type=ch_type, sphere=sphere)
        these_outlines = _make_head_outlines(
            this_sphere, pos, outlines=outlines, clip_origin=clip_origin
        )
        coef = model.coef_[data_picks]
        for data, ch_name in zip(coef.T, artifact_ch_names):
            # Start every topomap from the names belonging to this channel
            # type (these correspond to ``pos``); a copy is passed to the
            # merge step so one artifact channel cannot affect the next.
            plot_names = type_names
            if merge_channels:
                data, plot_names = _merge_ch_data(data, ch_type, list(type_names))
            ax = next(axes)
            plot_names = _prepare_sensor_names(plot_names, show_names)
            _plot_topomap_multi_cbar(
                data,
                pos,
                ax,
                title=f"{ch_type}/{ch_name}",
                vlim=vlim,
                cmap=cmap,
                outlines=these_outlines,
                colorbar=colorbar,
                unit="",
                cbar_fmt=cbar_fmt,
                sphere=this_sphere,
                ch_type=ch_type,
                sensors=sensors,
                names=plot_names,
                mask=mask,
                mask_params=mask_params,
                contours=contours,
                image_interp=image_interp,
                extrapolate=extrapolate,
                border=border,
                res=res,
                size=size,
                cnorm=cnorm,
            )
    # the suptitle is only set on figures created here, not on user axes
    if axes_was_none:
        fig.suptitle(title)
    plt_show(show)
    return fig