"""Dipole viz specific functions."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import os.path as op

import numpy as np
from scipy.spatial import ConvexHull

from .._freesurfer import _estimate_talxfm_rigid, _get_head_surface
from ..surface import read_surface
from ..transforms import _get_trans, apply_trans, invert_transform
from ..utils import _check_option, _validate_type, get_subjects_dir
from .utils import _validate_if_list_of_axes, plt_show


def _check_concat_dipoles(dipole):
    """Return a single Dipole, concatenating the input if it is not one.

    Parameters
    ----------
    dipole : Dipole | object
        A ``Dipole`` instance (returned unchanged) or anything accepted by
        ``_concatenate_dipoles`` (presumably a list of dipoles).

    Returns
    -------
    dipole : Dipole
        The original or concatenated dipole object.
    """
    from ..dipole import Dipole, _concatenate_dipoles

    if not isinstance(dipole, Dipole):
        dipole = _concatenate_dipoles(dipole)
    return dipole


def _plot_dipole_mri_outlines(
    dipoles,
    *,
    subject,
    trans,
    ax,
    subjects_dir,
    color,
    scale,
    coord_frame,
    show,
    block,
    head_source,
    title,
    surf,
    width,
):
    """Plot dipoles over axial, coronal and sagittal surface outlines.

    Parameters
    ----------
    dipoles : Dipole | object
        Dipole(s) to plot; passed through ``_check_concat_dipoles``.
    subject : str
        Subject name, used to locate ``<subjects_dir>/<subject>/surf``.
    trans : object
        Head <-> MRI transform, resolved via ``_get_trans(fro="head",
        to="mri")``.
    ax : None | sequence of Axes
        Three axes (axial, coronal, sagittal). If None, a new 1x3 figure is
        created.
    subjects_dir : path-like | None
        FreeSurfer subjects directory, resolved with ``get_subjects_dir``.
    color : color | None
        Color of the dipole arrows and markers (default ``"r"``).
    scale : float | None
        Arrow length factor; orientation vectors are multiplied by
        ``1000 * scale`` (default 0.03).
    coord_frame : "head" | "mri" | "mri_rotated"
        Coordinate frame in which everything is drawn.
    show, block : bool
        Passed to ``plt_show``.
    head_source : object
        Passed to ``_get_head_surface`` to obtain the head surface.
    title : str | None
        Figure suptitle, if not None.
    surf : "white" | "pial" | None
        Cortical surface whose outlines are drawn; None draws only the head.
    width : float | None
        Quiver shaft width (default 0.015); also sets the marker radius.

    Returns
    -------
    fig : Figure
        The figure containing ``ax``.
    """
    import matplotlib.pyplot as plt
    from matplotlib.collections import LineCollection, PatchCollection
    from matplotlib.patches import Circle

    extra = 'when mode is "outlines"'
    trans = _get_trans(trans, fro="head", to="mri")[0]
    _check_option(
        "coord_frame", coord_frame, ["head", "mri", "mri_rotated"], extra=extra
    )
    _validate_type(surf, (str, None), "surf")
    _check_option("surf", surf, ("white", "pial", None))
    if ax is None:
        _, ax = plt.subplots(
            1, 3, figsize=(7, 2.5), squeeze=True, layout="constrained"
        )
    _validate_if_list_of_axes(ax, 3, name="ax")
    dipoles = _check_concat_dipoles(dipoles)
    color = "r" if color is None else color
    scale = 0.03 if scale is None else scale
    width = 0.015 if width is None else width
    fig = ax[0].figure

    # Resolve subjects_dir *before* any path is built from it, so that a None
    # argument (to be looked up by get_subjects_dir) also works for the
    # cortical surfaces and not only for the head surface.
    subjects_dir = get_subjects_dir(subjects_dir)
    if subjects_dir is not None:
        subjects_dir = str(subjects_dir)

    surfs = dict()
    hemis = ("lh", "rh")
    if surf is not None:
        for hemi in hemis:
            surfs[hemi] = read_surface(
                op.join(subjects_dir, subject, "surf", f"{hemi}.{surf}"),
                return_dict=True,
            )[2]
            # Divide by 1000 so the cortex shares units with the head surface
            # before the common transform below (assumes read_surface yields
            # mm and the head surface is in m -- TODO confirm).
            surfs[hemi]["rr"] /= 1000.0
    surfs["head"] = _get_head_surface(head_source, subject, subjects_dir)
    del head_source

    # mri_trans maps the surfaces (MRI frame) and head_trans maps the dipoles
    # (head frame) into the requested plotting frame.
    mri_trans = head_trans = np.eye(4)
    if coord_frame in ("mri", "mri_rotated"):
        head_trans = trans["trans"]
        if coord_frame == "mri_rotated":
            rot = _estimate_talxfm_rigid(subject, subjects_dir)
            rot[:3, 3] = 0.0  # keep only the rotational part
            head_trans = rot @ head_trans
            mri_trans = rot @ mri_trans
    else:
        assert coord_frame == "head"
        mri_trans = invert_transform(trans)["trans"]
    # Everything is drawn in millimeters.
    for surf_dict in surfs.values():
        surf_dict["rr"] = 1000 * apply_trans(mri_trans, surf_dict["rr"])
    del mri_trans

    # Slice level (plotting frame, mm) for each surface and view. A view with
    # no level for a surface simply does not draw that surface.
    if surf is not None:
        use_rr = np.concatenate([surfs[key]["rr"] for key in hemis])
    else:
        use_rr = surfs["head"]["rr"]
    # axial: 20th percentile of Z; coronal: 55th percentile of Y
    axial = float(np.percentile(use_rr[:, 2], 20.0))
    coronal = float(np.percentile(use_rr[:, 1], 55.0))
    levels = {key: dict(Axial=axial, Coronal=coronal) for key in hemis + ("head",)}
    if surf is not None:
        # The sagittal cortical outline is taken through the right hemisphere.
        levels["rh"]["Sagittal"] = float(np.percentile(surfs["rh"]["rr"][:, 0], 50))
    levels["head"]["Sagittal"] = 0.0

    views = [("Axial", "XY"), ("Coronal", "XZ"), ("Sagittal", "YZ")]
    axis_index = dict(X=0, Y=1, Z=2)
    # Dipole positions / orientations (mm) are the same for every view.
    pos = 1000 * apply_trans(head_trans, dipoles.pos)
    ori = 1000 * apply_trans(head_trans, dipoles.ori, move=False)
    head_rr = surfs["head"]["rr"]
    for ax_, (name, coords) in zip(ax, views):
        idx = [axis_index[char] for char in coords]
        # The axis orthogonal to this view, along which the surfaces are sliced
        miss = np.setdiff1d(np.arange(3), idx)[0]
        # Axis limits follow the extent of the head surface
        lims = {
            char: np.array([head_rr[:, ii].min(), head_rr[:, ii].max()])
            for char, ii in zip(coords, idx)
        }
        ax_.quiver(
            pos[:, idx[0]],
            pos[:, idx[1]],
            scale * ori[:, idx[0]],
            scale * ori[:, idx[1]],
            color=color,
            pivot="middle",
            zorder=5,
            scale_units="xy",
            angles="xy",
            scale=1.0,
            width=width,
            minshaft=0.5,
            headwidth=2.5,
            headlength=2.5,
            headaxislength=2,
        )
        coll = PatchCollection(
            [
                Circle((x, y), radius=scale * 1000 * width * 6)
                for x, y in zip(pos[:, idx[0]], pos[:, idx[1]])
            ],
            linewidths=0.0,
            facecolors=color,
            zorder=6,
        )
        for key, surf_dict in surfs.items():
            try:
                level = levels[key][name]
            except KeyError:
                continue
            if key != "head":
                # Light outline of the hemisphere's 2D projection (convex
                # hull edges of the projected vertices).
                rrs = surf_dict["rr"][:, idx]
                tris = ConvexHull(rrs).simplices
                segments = LineCollection(
                    rrs[tris],
                    linewidths=1,
                    linestyles="-",
                    colors="k",
                    zorder=3,
                    alpha=0.25,
                )
                ax_.add_collection(segments)
            ax_.tricontour(
                surf_dict["rr"][:, idx[0]],
                surf_dict["rr"][:, idx[1]],
                surf_dict["tris"],
                surf_dict["rr"][:, miss],
                levels=[level],
                colors="k",
                linewidths=1.0,
                linestyles=["-"],
                zorder=4,
                alpha=0.5,
            )
            # TODO: disabling clipping on the contour collections breaks the
            # PatchCollection in MPL, so contours stay clipped for now.
        ax_.add_collection(coll)
        ax_.set(
            title=name,
            xlim=lims[coords[0]],
            ylim=lims[coords[1]],
            xlabel=coords[0] + " (mm)",
            ylabel=coords[1] + " (mm)",
        )
        for spine in ax_.spines.values():
            spine.set_visible(False)
        ax_.grid(True, ls=":", zorder=2)
        ax_.set_aspect("equal")

    if title is not None:
        fig.suptitle(title)
    plt_show(show, block=block)
    return fig


def _plot_dipole_3d(dipoles, *, coord_frame, color, fig, trans, scale, mode):
    """Plot dipoles as spheres (and optionally arrows) in a 3D renderer.

    Parameters
    ----------
    dipoles : Dipole
        Object providing ``pos`` and ``ori`` arrays in head coordinates.
    coord_frame : "head" | "mri"
        Frame in which the dipoles are drawn.
    color : color | None
        Sphere/arrow color (default ``"r"``).
    fig : object | None
        Existing figure/scene, passed to ``_get_renderer``.
    trans : object
        Head <-> MRI transform, used when ``coord_frame != "head"``.
    scale : float | None
        Sphere scale (default 0.005); arrows use ``3 * scale``.
    mode : str
        If ``"arrow"``, orientation arrows are drawn as well.

    Returns
    -------
    fig : object
        The renderer scene.
    """
    from .backends.renderer import _get_renderer

    _check_option("coord_frame", coord_frame, ("head", "mri"))
    color = "r" if color is None else color
    scale = 0.005 if scale is None else scale
    renderer = _get_renderer(fig=fig, size=(600, 600))
    pos = dipoles.pos
    ori = dipoles.ori
    if coord_frame != "head":
        trans = _get_trans(trans, fro="head", to=coord_frame)[0]
        pos = apply_trans(trans, pos)
        # Orientations are direction vectors: rotate them, never translate.
        ori = apply_trans(trans, ori, move=False)
    renderer.sphere(center=pos, color=color, scale=scale)
    if mode == "arrow":
        x, y, z = pos.T
        u, v, w = ori.T
        renderer.quiver3d(x, y, z, u, v, w, scale=3 * scale, color=color, mode="arrow")
    renderer.show()
    fig = renderer.scene()
    return fig