215 lines
6.9 KiB
Python
215 lines
6.9 KiB
Python
"""Dipole viz specific functions."""
|
|
|
|
# Authors: The MNE-Python contributors.
|
|
# License: BSD-3-Clause
|
|
# Copyright the MNE-Python contributors.
|
|
|
|
import os.path as op
|
|
|
|
import numpy as np
|
|
from scipy.spatial import ConvexHull
|
|
|
|
from .._freesurfer import _estimate_talxfm_rigid, _get_head_surface
|
|
from ..surface import read_surface
|
|
from ..transforms import _get_trans, apply_trans, invert_transform
|
|
from ..utils import _check_option, _validate_type, get_subjects_dir
|
|
from .utils import _validate_if_list_of_axes, plt_show
|
|
|
|
|
|
def _check_concat_dipoles(dipole):
    """Return ``dipole`` as a single ``Dipole`` instance.

    An input that already is a ``Dipole`` is returned unchanged; anything
    else is handed to ``_concatenate_dipoles`` and that result is returned.
    """
    # Function-scope import, as in the original
    # (presumably to avoid a circular import -- TODO confirm).
    from ..dipole import Dipole, _concatenate_dipoles

    if isinstance(dipole, Dipole):
        return dipole
    return _concatenate_dipoles(dipole)
|
|
|
|
|
|
def _plot_dipole_mri_outlines(
    dipoles,
    *,
    subject,
    trans,
    ax,
    subjects_dir,
    color,
    scale,
    coord_frame,
    show,
    block,
    head_source,
    title,
    surf,
    width,
):
    """Plot dipoles as arrows over head/brain outlines in three 2D views.

    One axes is drawn per view: "Axial" (XY), "Coronal" (XZ) and "Sagittal"
    (YZ). Each axes shows the head-surface contour at a fixed slice level,
    optionally the per-hemisphere cortical contour plus its 2D convex hull,
    and one arrow and one filled circle per dipole. All plotted coordinates
    are multiplied by 1000 and the axes are labelled in mm.

    Parameters
    ----------
    dipoles : Dipole | list of Dipole
        Dipoles to plot. Anything that is not a ``Dipole`` is concatenated
        via ``_check_concat_dipoles``.
    subject : str
        Subject name; used to build the surface file paths and passed on to
        the head-surface and Talairach helpers.
    trans : object
        Anything accepted by ``_get_trans``; it is requested as the
        head -> MRI transform.
    ax : sequence of 3 Axes | None
        Axes to draw into. If None, a new 1 x 3 figure is created.
    subjects_dir : path-like | None
        Subjects directory; resolved with ``get_subjects_dir`` before any
        file is read.
    color : color | None
        Color of arrows and circles. Defaults to ``"r"``.
    scale : float | None
        Multiplier for the arrow length and circle radius. Defaults to 0.03.
    coord_frame : "head" | "mri" | "mri_rotated"
        Coordinate frame to plot in. ``"mri_rotated"`` additionally applies
        the rotation part of ``_estimate_talxfm_rigid``.
    show : bool
        Passed to ``plt_show``.
    block : bool
        Passed to ``plt_show``.
    head_source : object
        Passed to ``_get_head_surface``.
    title : str | None
        If not None, used as the figure suptitle.
    surf : "white" | "pial" | None
        Cortical surface to outline for both hemispheres; None draws the
        head outline only.
    width : float | None
        Arrow shaft width passed to ``quiver``; also scales the circle
        radius. Defaults to 0.015.

    Returns
    -------
    fig : matplotlib Figure
        The figure that owns ``ax[0]``.
    """
    import matplotlib.pyplot as plt
    from matplotlib.collections import LineCollection, PatchCollection
    from matplotlib.patches import Circle

    # --- validate inputs (same order as before) and fill in defaults
    trans = _get_trans(trans, fro="head", to="mri")[0]
    _check_option(
        "coord_frame",
        coord_frame,
        ["head", "mri", "mri_rotated"],
        extra='when mode is "outlines"',
    )
    _validate_type(surf, (str, None), "surf")
    _check_option("surf", surf, ("white", "pial", None))
    if ax is None:
        _, ax = plt.subplots(1, 3, figsize=(7, 2.5), squeeze=True, layout="constrained")
    _validate_if_list_of_axes(ax, 3, name="ax")
    dipoles = _check_concat_dipoles(dipoles)
    color = "r" if color is None else color
    scale = 0.03 if scale is None else scale
    width = 0.015 if width is None else width
    fig = ax[0].figure

    # Resolve subjects_dir *before* it is used to build surface paths;
    # otherwise subjects_dir=None fails in op.join even when a default
    # subjects directory is configured.
    subjects_dir = get_subjects_dir(subjects_dir)
    if subjects_dir is not None:
        subjects_dir = str(subjects_dir)

    # --- load surfaces (insertion order lh, rh, head is the drawing order)
    hemis = ("lh", "rh")
    surfs = dict()
    if surf is not None:
        for hemi in hemis:
            surfs[hemi] = read_surface(
                op.join(subjects_dir, subject, "surf", f"{hemi}.{surf}"),
                return_dict=True,
            )[2]
            # Divided by 1000 here and multiplied by 1000 again after the
            # transform below (presumably mm -> m -> mm; TODO confirm units).
            surfs[hemi]["rr"] /= 1000.0
    surfs["head"] = _get_head_surface(head_source, subject, subjects_dir)

    # --- transforms: head_trans is applied to dipoles, mri_trans to surfaces
    mri_trans = head_trans = np.eye(4)
    if coord_frame in ("mri", "mri_rotated"):
        head_trans = trans["trans"]
        if coord_frame == "mri_rotated":
            rot = _estimate_talxfm_rigid(subject, subjects_dir)
            rot[:3, 3] = 0.0  # keep the rotation only, drop the translation
            head_trans = rot @ head_trans
            mri_trans = rot @ mri_trans
    else:
        assert coord_frame == "head"
        mri_trans = invert_transform(trans)["trans"]
    for this_surf in surfs.values():
        this_surf["rr"] = 1000 * apply_trans(mri_trans, this_surf["rr"])

    # --- slice level for each (surface, view) pair
    if surf is not None:
        use_rr = np.concatenate([surfs[key]["rr"] for key in hemis])
    else:
        use_rr = surfs["head"]["rr"]
    views = [("Axial", "XY"), ("Coronal", "XZ"), ("Sagittal", "YZ")]
    # axial slice at the 20th percentile of Z, coronal at the 55th of Y
    axial = float(np.percentile(use_rr[:, 2], 20.0))
    coronal = float(np.percentile(use_rr[:, 1], 55.0))
    levels = {key: dict(Axial=axial, Coronal=coronal) for key in hemis + ("head",)}
    if surf is not None:
        # Only "rh" and "head" get a sagittal level, so "lh" is skipped in
        # the sagittal view (KeyError handled in the loop below).
        levels["rh"]["Sagittal"] = float(np.percentile(surfs["rh"]["rr"][:, 0], 50))
    levels["head"]["Sagittal"] = 0.0

    # Dipole positions/orientations do not depend on the view: compute once.
    pos = 1000 * apply_trans(head_trans, dipoles.pos)
    ori = 1000 * apply_trans(head_trans, dipoles.ori, move=False)
    head_rr = surfs["head"]["rr"]
    axis_index = dict(X=0, Y=1, Z=2)

    for ax_, (name, coords) in zip(ax, views):
        idx = [axis_index[char] for char in coords]  # the two in-plane axes
        miss = np.setdiff1d(np.arange(3), idx)[0]  # the through-plane axis
        # Axis limits come from the head surface extent in this plane
        lims = {
            char: np.array([head_rr[:, ii].min(), head_rr[:, ii].max()])
            for char, ii in zip(coords, idx)
        }
        ax_.quiver(
            pos[:, idx[0]],
            pos[:, idx[1]],
            scale * ori[:, idx[0]],
            scale * ori[:, idx[1]],
            color=color,
            pivot="middle",
            zorder=5,
            scale_units="xy",
            angles="xy",
            scale=1.0,
            width=width,
            minshaft=0.5,
            headwidth=2.5,
            headlength=2.5,
            headaxislength=2,
        )
        coll = PatchCollection(
            [
                Circle((x, y), radius=scale * 1000 * width * 6)
                for x, y in zip(pos[:, idx[0]], pos[:, idx[1]])
            ],
            linewidths=0.0,
            facecolors=color,
            zorder=6,
        )
        # NOTE: the loop variable is deliberately not called ``surf`` so the
        # ``surf`` parameter is not shadowed.
        for key, this_surf in surfs.items():
            try:
                level = levels[key][name]
            except KeyError:
                continue  # no slice level defined for this surface/view
            if key != "head":
                # Faint convex-hull outline of the projected hemisphere
                rrs = this_surf["rr"][:, idx]
                tris = ConvexHull(rrs).simplices
                segments = LineCollection(
                    rrs[tris],
                    linewidths=1,
                    linestyles="-",
                    colors="k",
                    zorder=3,
                    alpha=0.25,
                )
                ax_.add_collection(segments)
            # Contour of the surface where the through-plane coordinate
            # equals the slice level
            ax_.tricontour(
                this_surf["rr"][:, idx[0]],
                this_surf["rr"][:, idx[1]],
                this_surf["tris"],
                this_surf["rr"][:, miss],
                levels=[level],
                colors="k",
                linewidths=1.0,
                linestyles=["-"],
                zorder=4,
                alpha=0.5,
            )
        # TODO: turning clipping off on the contour collections breaks the
        # PatchCollection in MPL, so clipping is left at its default.
        ax_.add_collection(coll)
        ax_.set(
            title=name,
            xlim=lims[coords[0]],
            ylim=lims[coords[1]],
            xlabel=coords[0] + " (mm)",
            ylabel=coords[1] + " (mm)",
        )
        for spine in ax_.spines.values():
            spine.set_visible(False)
        ax_.grid(True, ls=":", zorder=2)
        ax_.set_aspect("equal")

    if title is not None:
        fig.suptitle(title)
    plt_show(show, block=block)

    return fig
|
|
|
|
|
|
def _plot_dipole_3d(dipoles, *, coord_frame, color, fig, trans, scale, mode):
    """Render dipoles as spheres, optionally with arrows, in a 3D scene.

    Parameters
    ----------
    dipoles : object
        Object exposing ``pos`` and ``ori`` arrays (assumed to be a
        ``Dipole`` with one row per dipole -- TODO confirm with callers).
    coord_frame : "head" | "mri"
        Frame to plot in. For ``"mri"`` the positions and orientations are
        transformed from head coordinates using ``trans``.
    color : color | None
        Color of spheres and arrows. Defaults to ``"r"``.
    fig : object
        Passed to ``_get_renderer`` as ``fig``.
    trans : object
        Anything accepted by ``_get_trans``; only used when ``coord_frame``
        is not ``"head"``.
    scale : float | None
        Sphere scale; arrows use three times this value. Defaults to 0.005.
    mode : str
        Arrows are added only when this equals ``"arrow"``.

    Returns
    -------
    fig : object
        The scene returned by ``renderer.scene()``.
    """
    from .backends.renderer import _get_renderer

    _check_option("coord_frame", coord_frame, ("head", "mri"))
    color = "r" if color is None else color
    scale = 0.005 if scale is None else scale
    renderer = _get_renderer(fig=fig, size=(600, 600))
    pos = dipoles.pos
    ori = dipoles.ori
    if coord_frame != "head":
        trans = _get_trans(trans, fro="head", to=coord_frame)[0]
        pos = apply_trans(trans, pos)
        # Orientations are direction vectors: rotate them but do not add the
        # translation (same as the ``move=False`` call in
        # ``_plot_dipole_mri_outlines``).
        ori = apply_trans(trans, ori, move=False)

    renderer.sphere(center=pos, color=color, scale=scale)
    if mode == "arrow":
        x, y, z = pos.T
        u, v, w = ori.T
        renderer.quiver3d(x, y, z, u, v, w, scale=3 * scale, color=color, mode="arrow")
    renderer.show()
    fig = renderer.scene()
    return fig
|