470 lines
15 KiB
Python
470 lines
15 KiB
Python
"""Functions to plot on circle as for connectivity."""
|
|
|
|
# Authors: The MNE-Python contributors.
|
|
# License: BSD-3-Clause
|
|
# Copyright the MNE-Python contributors.
|
|
|
|
from functools import partial
|
|
from itertools import cycle
|
|
|
|
import numpy as np
|
|
|
|
from ..utils import _validate_type
|
|
from .utils import _get_cmap, plt_show
|
|
|
|
|
|
def circular_layout(
    node_names,
    node_order,
    start_pos=90,
    start_between=True,
    group_boundaries=None,
    group_sep=10,
):
    """Create layout arranging nodes on a circle.

    Parameters
    ----------
    node_names : list of str
        Node names.
    node_order : list of str | array-like of str
        Node names defining the order in which the nodes are arranged. Must
        contain the same elements as node_names, but the order can be
        different. Consecutive nodes in this order get increasing angles,
        starting at "start_pos" degrees (counterclockwise on a default
        matplotlib polar axes).
    start_pos : float
        Angle in degrees that defines where the first node is plotted.
    start_between : bool
        If True, the layout starts with the position between the nodes, i.e.
        half of the node separation is added to start_pos.
    group_boundaries : None | array-like
        Positions (indices into node_order) before which a "group_sep" gap is
        inserted. Values must be strictly increasing and lie in
        [0, n_nodes - 1]. E.g. "[0, len(node_names) / 2]" will create two
        groups. An empty list is treated like None.
    group_sep : float
        Group separation angle in degrees. See "group_boundaries".

    Returns
    -------
    node_angles : array, shape=(n_node_names,)
        Node angles in degrees, in the same order as node_names. Values are
        not wrapped to [0, 360).

    Raises
    ------
    ValueError
        If node_order does not match node_names (different length, a missing
        name, or repeated entries), or if group_boundaries is out of range or
        not strictly increasing.
    """
    # accept any iterable (lists, tuples, numpy arrays of names)
    node_order = list(node_order)
    n_nodes = len(node_names)

    if len(node_order) != n_nodes:
        raise ValueError("node_order has to be the same length as node_names")

    if group_boundaries is not None:
        boundaries = np.array(group_boundaries, dtype=np.int64)
        if np.any(boundaries >= n_nodes) or np.any(boundaries < 0):
            raise ValueError('"group_boundaries" has to be between 0 and n_nodes - 1.')
        if len(boundaries) > 1 and np.any(np.diff(boundaries) <= 0):
            raise ValueError('"group_boundaries" must have strictly increasing values.')
        n_group_sep = len(boundaries)
    else:
        n_group_sep = 0
        boundaries = None

    # position of each name within node_order; the first occurrence wins,
    # matching the semantics of list.index
    position = {}
    for pos, name in enumerate(node_order):
        position.setdefault(name, pos)
    missing = [name for name in node_names if name not in position]
    if missing:
        raise ValueError(f"node_names {missing} are not contained in node_order")
    order_idx = np.array([position[name] for name in node_names], dtype=np.int64)
    if len(np.unique(order_idx)) != n_nodes:
        raise ValueError("node_order has repeated entries")

    if n_nodes == 0:
        # nothing to place; avoids dividing by zero below
        return np.zeros(0, dtype=np.float64)

    # the full circle is shared between node steps and group gaps
    node_sep = (360.0 - n_group_sep * group_sep) / n_nodes

    if start_between:
        start_pos += node_sep / 2

    if boundaries is not None and len(boundaries) > 0 and boundaries[0] == 0:
        # special case when a group separator is at the start: centre that
        # gap on start_pos (half before the first node, half after the last)
        start_pos += group_sep / 2
        boundaries = boundaries[1:] if n_group_sep > 1 else None

    # per-position angular increments; their cumulative sum gives positions
    node_angles = np.full(n_nodes, node_sep, dtype=np.float64)
    node_angles[0] = start_pos
    if boundaries is not None:
        node_angles[boundaries] += group_sep

    # reorder from "position in node_order" to "position in node_names"
    return np.cumsum(node_angles)[order_idx]
|
|
|
|
|
|
def _plot_connectivity_circle_onpick(
    event,
    fig=None,
    ax=None,
    indices=None,
    n_nodes=0,
    node_angles=None,
    ylim=(9, 10),
):
    """Isolate connections around a single node when user left clicks a node.

    On right click, resets all connections.

    Parameters
    ----------
    event : matplotlib mouse event
        The ``button_press_event``; ``inaxes``, ``button``, ``xdata`` (angle,
        radians) and ``ydata`` (radius) are read.
    fig : matplotlib.figure.Figure
        Figure whose canvas is redrawn after visibility changes.
    ax : matplotlib.projections.polar.PolarAxes
        Only clicks inside this axes are handled.
    indices : sequence of two index arrays
        Connection ``k`` joins nodes ``indices[0][k]`` and ``indices[1][k]``.
        This relies on the first ``len(indices[0])`` patches of the axes
        being the connection lines, in the same order.
    n_nodes : int
        Unused; kept for backward compatibility.
    node_angles : array-like of float
        Node angles in radians.
    ylim : tuple of float
        Radial band ``(low, high)`` in which a left click counts as hitting
        the node ring.
    """
    if event.inaxes is None or event.inaxes != ax:
        return

    if event.button == 1:  # left click
        if event.xdata is None or event.ydata is None:
            return
        # click must be near node radius
        if not ylim[0] <= event.ydata <= ylim[1]:
            return

        # angular distance measured around the circle, so a click just past
        # 0 rad still selects a node sitting just below 2*pi
        two_pi = 2 * np.pi
        delta = (event.xdata - np.asarray(node_angles, dtype=np.float64)) % two_pi
        node = np.argmin(np.minimum(delta, two_pi - delta))

        patches = event.inaxes.patches
        for ii, (x, y) in enumerate(zip(indices[0], indices[1])):
            patches[ii].set_visible(node in [x, y])
        fig.canvas.draw()
    elif event.button == 3:  # right click
        patches = event.inaxes.patches
        for ii in range(len(indices[0])):
            patches[ii].set_visible(True)
        fig.canvas.draw()
|
|
|
|
|
|
def _circle_node_spacing(node_angles):
    """Return the smallest angular gap (in radians) between any two nodes.

    Gaps are measured around the circle (modulo ``2 * pi``), so nodes on
    either side of 0 rad count as neighbours. With a single node there is no
    pair, and the diagonal placeholder ``1e9`` is returned.
    """
    two_pi = 2 * np.pi
    dist_mat = np.abs(node_angles[None, :] - node_angles[:, None]) % two_pi
    dist_mat = np.minimum(dist_mat, two_pi - dist_mat)
    # a node's zero distance to itself must not win the minimum
    dist_mat[np.diag_indices(len(node_angles))] = 1e9
    return np.min(dist_mat)


def _circle_select_connections(con, indices, n_lines):
    """Keep the connections to draw and order them weakest to strongest.

    Parameters
    ----------
    con : 1D array
        Connection values; NaN entries are treated as "no connection".
    indices : sequence of two 1D arrays
        Node indices of the two ends of each entry of ``con``.
    n_lines : int | None
        If not None, keep only the ``n_lines`` strongest connections by
        absolute value (all ties at the threshold are kept).

    Returns
    -------
    con : 1D array
        Selected values, sorted by increasing absolute value.
    indices : list of two 1D arrays
        The matching node indices, in the same order.
    """
    con_abs = np.abs(con)
    # rank only non-NaN values: np.sort places NaN last, so counting NaNs
    # would make the threshold NaN and drop every connection
    valid_abs = con_abs[~np.isnan(con_abs)]
    if n_lines is not None and len(valid_abs) > n_lines:
        con_thresh = np.sort(valid_abs)[-n_lines]
    else:
        con_thresh = 0.0

    # NaN >= con_thresh is False, so missing entries are dropped here too
    draw_idx = np.where(con_abs >= con_thresh)[0]
    con = con[draw_idx]
    con_abs = con_abs[draw_idx]
    indices = [ind[draw_idx] for ind in indices]

    # weakest first, so the strongest connections are added (drawn) last
    sort_idx = np.argsort(con_abs)
    return con[sort_idx], [ind[sort_idx] for ind in indices]


def _circle_endpoint_noise(indices, n_nodes, node_width):
    """Return angular jitter for the start and end point of each connection.

    Jitter is drawn uniformly from +/- a quarter node width with a fixed seed
    (so plots are reproducible), then scaled per node by the fraction of that
    node's connections that come after this one. Since connections are
    ordered weakest first, a node's last (strongest) connection gets zero
    jitter and ends at the node centre.
    """
    nodes_n_con = np.zeros(n_nodes, dtype=np.int64)
    for i, j in zip(indices[0], indices[1]):
        nodes_n_con[i] += 1
        nodes_n_con[j] += 1

    # initialize random number generator so plot is reproducible
    rng = np.random.mtrand.RandomState(0)

    n_con = len(indices[0])
    noise_max = 0.25 * node_width
    start_noise = rng.uniform(-noise_max, noise_max, n_con)
    end_noise = rng.uniform(-noise_max, noise_max, n_con)

    nodes_n_con_seen = np.zeros_like(nodes_n_con)
    for i, (start, end) in enumerate(zip(indices[0], indices[1])):
        nodes_n_con_seen[start] += 1
        nodes_n_con_seen[end] += 1

        start_noise[i] *= (nodes_n_con[start] - nodes_n_con_seen[start]) / float(
            nodes_n_con[start]
        )
        end_noise[i] *= (nodes_n_con[end] - nodes_n_con_seen[end]) / float(
            nodes_n_con[end]
        )
    return start_noise, end_noise


def _circle_label_orientation(angle_deg):
    """Return ``(rotation, ha)`` that keep a node label upright.

    The angle is first reduced to [0, 360). Labels on the left half of the
    circle ([90, 270) degrees) are flipped by 180 degrees and right-aligned
    so the text is not upside down; all others are left-aligned at their own
    angle.
    """
    angle_deg = angle_deg % 360
    if 90 <= angle_deg < 270:
        return angle_deg + 180, "right"
    return angle_deg, "left"


def _plot_connectivity_circle(
    con,
    node_names,
    indices=None,
    n_lines=None,
    node_angles=None,
    node_width=None,
    node_height=None,
    node_colors=None,
    facecolor="black",
    textcolor="white",
    node_edgecolor="black",
    linewidth=1.5,
    colormap="hot",
    vmin=None,
    vmax=None,
    colorbar=True,
    title=None,
    colorbar_size=None,
    colorbar_pos=None,
    fontsize_title=12,
    fontsize_names=8,
    fontsize_colorbar=8,
    padding=6.0,
    ax=None,
    interactive=True,
    node_linewidth=2.0,
    show=True,
):
    """Draw connectivity as curved lines between nodes placed on a circle.

    Parameters
    ----------
    con : array
        Either a 1D array of connection values (``indices`` required) or a
        square ``(n_nodes, n_nodes)`` matrix, of which only the strictly
        lower triangle is used. NaN entries are not drawn.
    node_names : list of str
        Node names, also used as labels.
    indices : sequence of two index arrays | None
        For 1D ``con``: the two node indices of each connection. Replaced by
        the lower-triangle indices when ``con`` is 2D.
    n_lines : int | None
        If not None, only the strongest ``n_lines`` connections (by absolute
        value) are drawn.
    node_angles : array-like | None
        Node angles in degrees. If None, nodes are evenly spaced from 0.
    node_width : float | None
        Node width in degrees. If None, the smallest angular gap between two
        nodes is used.
    node_height : float | None
        Radial height of the node ring. If None, 1.
    node_colors : list | None
        Node colors, cycled if shorter than the number of nodes. If None,
        colors are taken from the Spectral colormap.
    facecolor, textcolor, node_edgecolor : color
        Background, text and node edge colors.
    linewidth : float
        Line width of the connections.
    colormap : str | Colormap
        Colormap for the connection values.
    vmin, vmax : float | None
        Color scale limits. Default to the minimum / maximum of the drawn
        values.
    colorbar : bool
        Whether to add a colorbar.
    title : str | None
        Axes title.
    colorbar_size : float | None
        Passed to ``fig.colorbar`` as ``shrink``.
    colorbar_pos : tuple | None
        Passed to ``fig.colorbar`` as ``anchor``.
    fontsize_title, fontsize_names, fontsize_colorbar : float
        Font sizes.
    padding : float
        Extra radial space added beyond the node ring (ylim is
        ``(0, 10 + padding)``).
    ax : PolarAxes | None
        Axes to draw into; if None, a new figure is created.
    interactive : bool
        If True, left click on a node shows only its connections and right
        click shows all again.
    node_linewidth : float
        Line width of the node edges.
    show : bool
        Whether to show the figure.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure.
    ax : PolarAxes
        The axes.

    Raises
    ------
    ValueError
        On inconsistent input shapes, or if there is no connection to draw
        while ``vmin`` or ``vmax`` has to be inferred from the data.
    """
    import matplotlib.patches as m_patches
    import matplotlib.path as m_path
    import matplotlib.pyplot as plt
    from matplotlib.projections.polar import PolarAxes

    _validate_type(ax, (None, PolarAxes))

    n_nodes = len(node_names)

    if node_angles is not None:
        if len(node_angles) != n_nodes:
            raise ValueError("node_angles has to be the same length as node_names")
        # convert it to radians (asarray so plain lists work as well)
        node_angles = np.asarray(node_angles, dtype=np.float64) * np.pi / 180
    else:
        # uniform layout on unit circle
        node_angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)

    if node_width is None:
        # widths correspond to the minimum angle between two nodes
        node_width = _circle_node_spacing(node_angles)
    else:
        node_width = node_width * np.pi / 180

    if node_height is None:
        node_height = 1.0

    if node_colors is not None:
        if len(node_colors) < n_nodes:
            node_colors = cycle(node_colors)
    else:
        # assign colors using colormap
        try:
            spectral = plt.cm.spectral
        except AttributeError:
            spectral = plt.cm.Spectral
        node_colors = [spectral(i / float(n_nodes)) for i in range(n_nodes)]

    # handle 1D and 2D connectivity information
    con = np.asanyarray(con)
    if con.ndim == 1:
        if indices is None:
            raise ValueError("indices has to be provided if con.ndim == 1")
        indices = [np.asarray(ind) for ind in indices]
        if any(len(ind) != len(con) for ind in indices):
            raise ValueError("each entry of indices must have the same length as con")
    elif con.ndim == 2:
        if con.shape[0] != n_nodes or con.shape[1] != n_nodes:
            raise ValueError("con has to be 1D or a square matrix")
        # we use the lower-triangular part
        indices = np.tril_indices(n_nodes, -1)
        con = con[indices]
    else:
        raise ValueError("con has to be 1D or a square matrix")

    # keep the (strongest) connections to draw, weakest first
    con, indices = _circle_select_connections(con, indices, n_lines)

    # Get vmin vmax for color scaling (checked before any figure is created)
    if len(con) == 0 and (vmin is None or vmax is None):
        raise ValueError(
            "There are no connections to plot (con is empty or all NaN), so "
            "vmin and vmax cannot be inferred; pass them explicitly."
        )
    if vmin is None:
        vmin = np.min(con)
    if vmax is None:
        vmax = np.max(con)
    vrange = vmax - vmin

    # scale connectivity for colormap (vmin<=>0, vmax<=>1)
    if vrange == 0:
        # all values equal: avoid 0/0 = NaN, which the colormap would render
        # with its "bad" color; map to the low end like plt.Normalize does
        con_val_scaled = np.zeros(len(con))
    else:
        con_val_scaled = (con - vmin) / vrange

    # We want to add some "noise" to the start and end position of the
    # edges, modulated by the number of connections of the node and the
    # connection strength, such that the strongest connections are closer
    # to the node center
    start_noise, end_noise = _circle_endpoint_noise(indices, n_nodes, node_width)

    # get the colormap
    colormap = _get_cmap(colormap)

    # Use a polar axes
    if ax is None:
        fig = plt.figure(figsize=(8, 8), facecolor=facecolor, layout="constrained")
        ax = fig.add_subplot(polar=True)
    else:
        fig = ax.figure
    ax.set_facecolor(facecolor)

    # No ticks, we'll put our own
    ax.set_xticks([])
    ax.set_yticks([])

    # Set y axes limit, add additional space if requested
    ax.set_ylim(0, 10 + padding)

    # Remove the black axes border which may obscure the labels
    ax.spines["polar"].set_visible(False)

    # Finally, we draw the connections; the interactive callback relies on
    # these being the first patches of the axes, in the order of `indices`
    for pos, (i, j) in enumerate(zip(indices[0], indices[1])):
        # Start and end point, with some noise
        t0, r0 = node_angles[i] + start_noise[pos], 10
        t1, r1 = node_angles[j] + end_noise[pos], 10

        verts = [(t0, r0), (t0, 5), (t1, 5), (t1, r1)]
        codes = [
            m_path.Path.MOVETO,
            m_path.Path.CURVE4,
            m_path.Path.CURVE4,
            m_path.Path.LINETO,
        ]
        path = m_path.Path(verts, codes)

        color = colormap(con_val_scaled[pos])

        # Actual line
        patch = m_patches.PathPatch(
            path, fill=False, edgecolor=color, linewidth=linewidth, alpha=1.0
        )
        ax.add_patch(patch)

    # Draw ring with colored nodes
    height = np.ones(n_nodes) * node_height
    bars = ax.bar(
        node_angles,
        height,
        width=node_width,
        bottom=9,
        edgecolor=node_edgecolor,
        lw=node_linewidth,
        facecolor=".9",
        align="center",
    )

    for bar, color in zip(bars, node_colors):
        bar.set_facecolor(color)

    # Draw node labels, flipped where needed so the text stays upright
    angles_deg = 180 * node_angles / np.pi
    for name, angle_rad, angle_deg in zip(node_names, node_angles, angles_deg):
        rotation, ha = _circle_label_orientation(angle_deg)
        ax.text(
            angle_rad,
            9.4 + node_height,
            name,
            size=fontsize_names,
            rotation=rotation,
            rotation_mode="anchor",
            horizontalalignment=ha,
            verticalalignment="center",
            color=textcolor,
        )

    if title is not None:
        ax.set_title(title, color=textcolor, fontsize=fontsize_title)

    if colorbar:
        sm = plt.cm.ScalarMappable(cmap=colormap, norm=plt.Normalize(vmin, vmax))
        sm.set_array(np.linspace(vmin, vmax))
        colorbar_kwargs = dict()
        if colorbar_size is not None:
            colorbar_kwargs.update(shrink=colorbar_size)
        if colorbar_pos is not None:
            colorbar_kwargs.update(anchor=colorbar_pos)
        cb = fig.colorbar(sm, ax=ax, **colorbar_kwargs)
        cb_yticks = plt.getp(cb.ax.axes, "yticklabels")
        cb.ax.tick_params(labelsize=fontsize_colorbar)
        plt.setp(cb_yticks, color=textcolor)

    # Add callback for interaction
    if interactive:
        callback = partial(
            _plot_connectivity_circle_onpick,
            fig=fig,
            ax=ax,
            indices=indices,
            n_nodes=n_nodes,
            node_angles=node_angles,
        )

        fig.canvas.mpl_connect("button_press_event", callback)

    plt_show(show)
    return fig, ax
|
|
|
|
|
|
def plot_channel_labels_circle(labels, colors=None, picks=None, **kwargs):
    """Plot labels for each channel in a circle plot.

    .. note:: This primarily makes sense for sEEG channels where each
              channel can be assigned an anatomical label as the electrode
              passes through various brain areas.

    Parameters
    ----------
    labels : dict
        Lists of labels (values) associated with each channel (keys).
    colors : dict | None
        The color (value) for each label (key). Every label used in
        ``labels`` must have an entry; extra entries are shown as
        unconnected label nodes. If None, node colors and the connection
        colormap are left to the plotting defaults.
    picks : list | tuple | None
        The channels to consider. If None, all channels in ``labels`` are
        used.
    **kwargs : kwargs
        Keyword arguments for
        :func:`mne_connectivity.viz.plot_connectivity_circle`. Values given
        here take precedence over the defaults chosen by this function.

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        The figure handle.
    axes : instance of matplotlib.projections.polar.PolarAxes
        The subplot handle.
    """
    from matplotlib.colors import LinearSegmentedColormap

    _validate_type(labels, dict, "labels")
    _validate_type(colors, (dict, None), "colors")
    _validate_type(picks, (list, tuple, None), "picks")
    if picks is not None:
        labels = {k: v for k, v in labels.items() if k in picks}
    ch_names = list(labels.keys())
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # layout is the same on every run (set iteration order is not)
    all_labels = list(
        dict.fromkeys(label for val in labels.values() for label in val)
    )

    label_cmap = None
    if colors is not None:
        for label in all_labels:
            if label not in colors:
                raise ValueError(f"No color provided for {label} in `colors`")
        # update all_labels, there may be unconnected labels in colors
        all_labels = list(colors.keys())
        # make colormap with one bin per label
        label_colors = [colors[label] for label in all_labels]
        node_colors = ["black"] * len(ch_names) + label_colors
        label_cmap = LinearSegmentedColormap.from_list(
            "label_cmap", label_colors, N=len(label_colors)
        )
    else:
        node_colors = None
    n_labels = len(all_labels)

    # channel nodes come first, followed by one node per label
    node_names = ch_names + all_labels
    n_ch = len(ch_names)
    label_pos = {label: ii for ii, label in enumerate(all_labels)}
    con = np.full((len(node_names), len(node_names)), np.nan)
    for idx, ch_name in enumerate(ch_names):
        for label in labels[ch_name]:
            pos = label_pos[label]
            # offset by n_ch so a label sharing a channel's name still maps
            # to its own label node
            node_idx = n_ch + pos
            # centre of the label's colormap bin; pos / n_labels can land just
            # below a bin edge in floating point (e.g. (1 / 49) * 49 < 1) and
            # pick the previous label's color
            label_color = (pos + 0.5) / n_labels
            con[idx, node_idx] = con[node_idx, idx] = label_color  # symmetric

    # plot
    node_order = ch_names + all_labels[::-1]
    node_angles = circular_layout(
        node_names, node_order, start_pos=90, group_boundaries=[0, len(ch_names)]
    )
    # provide defaults but don't overwrite
    kwargs.setdefault("node_angles", node_angles)
    kwargs.setdefault("colorbar", False)
    kwargs.setdefault("node_colors", node_colors)
    kwargs.setdefault("vmin", 0)
    kwargs.setdefault("vmax", 1)
    if label_cmap is not None:
        # without `colors` there is no label colormap; keep the plot default
        kwargs.setdefault("colormap", label_cmap)
    return _plot_connectivity_circle(con, node_names, **kwargs)
|