针对pulse-transit的工具
This commit is contained in:
214
dist/client/mne/viz/_dipole.py
vendored
Normal file
214
dist/client/mne/viz/_dipole.py
vendored
Normal file
@@ -0,0 +1,214 @@
|
||||
"""Dipole viz specific functions."""
|
||||
|
||||
# Authors: The MNE-Python contributors.
|
||||
# License: BSD-3-Clause
|
||||
# Copyright the MNE-Python contributors.
|
||||
|
||||
import os.path as op
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial import ConvexHull
|
||||
|
||||
from .._freesurfer import _estimate_talxfm_rigid, _get_head_surface
|
||||
from ..surface import read_surface
|
||||
from ..transforms import _get_trans, apply_trans, invert_transform
|
||||
from ..utils import _check_option, _validate_type, get_subjects_dir
|
||||
from .utils import _validate_if_list_of_axes, plt_show
|
||||
|
||||
|
||||
def _check_concat_dipoles(dipole):
|
||||
from ..dipole import Dipole, _concatenate_dipoles
|
||||
|
||||
if not isinstance(dipole, Dipole):
|
||||
dipole = _concatenate_dipoles(dipole)
|
||||
return dipole
|
||||
|
||||
|
||||
def _plot_dipole_mri_outlines(
|
||||
dipoles,
|
||||
*,
|
||||
subject,
|
||||
trans,
|
||||
ax,
|
||||
subjects_dir,
|
||||
color,
|
||||
scale,
|
||||
coord_frame,
|
||||
show,
|
||||
block,
|
||||
head_source,
|
||||
title,
|
||||
surf,
|
||||
width,
|
||||
):
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.collections import LineCollection, PatchCollection
|
||||
from matplotlib.patches import Circle
|
||||
|
||||
extra = 'when mode is "outlines"'
|
||||
trans = _get_trans(trans, fro="head", to="mri")[0]
|
||||
_check_option(
|
||||
"coord_frame", coord_frame, ["head", "mri", "mri_rotated"], extra=extra
|
||||
)
|
||||
_validate_type(surf, (str, None), "surf")
|
||||
_check_option("surf", surf, ("white", "pial", None))
|
||||
if ax is None:
|
||||
_, ax = plt.subplots(1, 3, figsize=(7, 2.5), squeeze=True, layout="constrained")
|
||||
_validate_if_list_of_axes(ax, 3, name="ax")
|
||||
dipoles = _check_concat_dipoles(dipoles)
|
||||
color = "r" if color is None else color
|
||||
scale = 0.03 if scale is None else scale
|
||||
width = 0.015 if width is None else width
|
||||
fig = ax[0].figure
|
||||
surfs = dict()
|
||||
hemis = ("lh", "rh")
|
||||
if surf is not None:
|
||||
for hemi in hemis:
|
||||
surfs[hemi] = read_surface(
|
||||
op.join(subjects_dir, subject, "surf", f"{hemi}.{surf}"),
|
||||
return_dict=True,
|
||||
)[2]
|
||||
surfs[hemi]["rr"] /= 1000.0
|
||||
subjects_dir = get_subjects_dir(subjects_dir)
|
||||
if subjects_dir is not None:
|
||||
subjects_dir = str(subjects_dir)
|
||||
surfs["head"] = _get_head_surface(head_source, subject, subjects_dir)
|
||||
del head_source
|
||||
mri_trans = head_trans = np.eye(4)
|
||||
if coord_frame in ("mri", "mri_rotated"):
|
||||
head_trans = trans["trans"]
|
||||
if coord_frame == "mri_rotated":
|
||||
rot = _estimate_talxfm_rigid(subject, subjects_dir)
|
||||
rot[:3, 3] = 0.0
|
||||
head_trans = rot @ head_trans
|
||||
mri_trans = rot @ mri_trans
|
||||
else:
|
||||
assert coord_frame == "head"
|
||||
mri_trans = invert_transform(trans)["trans"]
|
||||
for s in surfs.values():
|
||||
s["rr"] = 1000 * apply_trans(mri_trans, s["rr"])
|
||||
del mri_trans
|
||||
levels = dict()
|
||||
if surf is not None:
|
||||
use_rr = np.concatenate([surfs[key]["rr"] for key in hemis])
|
||||
else:
|
||||
use_rr = surfs["head"]["rr"]
|
||||
views = [("Axial", "XY"), ("Coronal", "XZ"), ("Sagittal", "YZ")]
|
||||
# axial: 25% up the Z axis
|
||||
axial = float(np.percentile(use_rr[:, 2], 20.0))
|
||||
coronal = float(np.percentile(use_rr[:, 1], 55.0))
|
||||
for key in hemis + ("head",):
|
||||
levels[key] = dict(Axial=axial, Coronal=coronal)
|
||||
if surf is not None:
|
||||
levels["rh"]["Sagittal"] = float(np.percentile(surfs["rh"]["rr"][:, 0], 50))
|
||||
levels["head"]["Sagittal"] = 0.0
|
||||
for ax_, (name, coords) in zip(ax, views):
|
||||
idx = list(map(dict(X=0, Y=1, Z=2).get, coords))
|
||||
miss = np.setdiff1d(np.arange(3), idx)[0]
|
||||
pos = 1000 * apply_trans(head_trans, dipoles.pos)
|
||||
ori = 1000 * apply_trans(head_trans, dipoles.ori, move=False)
|
||||
lims = dict()
|
||||
for ii, char in enumerate(coords):
|
||||
lim = surfs["head"]["rr"][:, idx[ii]]
|
||||
lim = np.array([lim.min(), lim.max()])
|
||||
lims[char] = lim
|
||||
ax_.quiver(
|
||||
pos[:, idx[0]],
|
||||
pos[:, idx[1]],
|
||||
scale * ori[:, idx[0]],
|
||||
scale * ori[:, idx[1]],
|
||||
color=color,
|
||||
pivot="middle",
|
||||
zorder=5,
|
||||
scale_units="xy",
|
||||
angles="xy",
|
||||
scale=1.0,
|
||||
width=width,
|
||||
minshaft=0.5,
|
||||
headwidth=2.5,
|
||||
headlength=2.5,
|
||||
headaxislength=2,
|
||||
)
|
||||
coll = PatchCollection(
|
||||
[
|
||||
Circle((x, y), radius=scale * 1000 * width * 6)
|
||||
for x, y in zip(pos[:, idx[0]], pos[:, idx[1]])
|
||||
],
|
||||
linewidths=0.0,
|
||||
facecolors=color,
|
||||
zorder=6,
|
||||
)
|
||||
for key, surf in surfs.items():
|
||||
try:
|
||||
level = levels[key][name]
|
||||
except KeyError:
|
||||
continue
|
||||
if key != "head":
|
||||
rrs = surf["rr"][:, idx]
|
||||
tris = ConvexHull(rrs).simplices
|
||||
segments = LineCollection(
|
||||
rrs[:, [0, 1]][tris],
|
||||
linewidths=1,
|
||||
linestyles="-",
|
||||
colors="k",
|
||||
zorder=3,
|
||||
alpha=0.25,
|
||||
)
|
||||
ax_.add_collection(segments)
|
||||
ax_.tricontour(
|
||||
surf["rr"][:, idx[0]],
|
||||
surf["rr"][:, idx[1]],
|
||||
surf["tris"],
|
||||
surf["rr"][:, miss],
|
||||
levels=[level],
|
||||
colors="k",
|
||||
linewidths=1.0,
|
||||
linestyles=["-"],
|
||||
zorder=4,
|
||||
alpha=0.5,
|
||||
)
|
||||
# TODO: this breaks the PatchCollection in MPL
|
||||
# for coll in h.collections:
|
||||
# coll.set_clip_on(False)
|
||||
ax_.add_collection(coll)
|
||||
ax_.set(
|
||||
title=name,
|
||||
xlim=lims[coords[0]],
|
||||
ylim=lims[coords[1]],
|
||||
xlabel=coords[0] + " (mm)",
|
||||
ylabel=coords[1] + " (mm)",
|
||||
)
|
||||
for spine in ax_.spines.values():
|
||||
spine.set_visible(False)
|
||||
ax_.grid(True, ls=":", zorder=2)
|
||||
ax_.set_aspect("equal")
|
||||
|
||||
if title is not None:
|
||||
fig.suptitle(title)
|
||||
plt_show(show, block=block)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def _plot_dipole_3d(dipoles, *, coord_frame, color, fig, trans, scale, mode):
|
||||
from .backends.renderer import _get_renderer
|
||||
|
||||
_check_option("coord_frame", coord_frame, ("head", "mri"))
|
||||
color = "r" if color is None else color
|
||||
scale = 0.005 if scale is None else scale
|
||||
renderer = _get_renderer(fig=fig, size=(600, 600))
|
||||
pos = dipoles.pos
|
||||
ori = dipoles.ori
|
||||
if coord_frame != "head":
|
||||
trans = _get_trans(trans, fro="head", to=coord_frame)[0]
|
||||
pos = apply_trans(trans, pos)
|
||||
ori = apply_trans(trans, ori)
|
||||
|
||||
renderer.sphere(center=pos, color=color, scale=scale)
|
||||
if mode == "arrow":
|
||||
x, y, z = pos.T
|
||||
u, v, w = ori.T
|
||||
renderer.quiver3d(x, y, z, u, v, w, scale=3 * scale, color=color, mode="arrow")
|
||||
renderer.show()
|
||||
fig = renderer.scene()
|
||||
return fig
|
||||
Reference in New Issue
Block a user