"""Functions to plot on circle as for connectivity."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from functools import partial
from itertools import cycle

import numpy as np

from ..utils import _validate_type
from .utils import _get_cmap, plt_show


def circular_layout(
    node_names,
    node_order,
    start_pos=90,
    start_between=True,
    group_boundaries=None,
    group_sep=10,
):
    """Create layout arranging nodes on a circle.

    Parameters
    ----------
    node_names : list of str
        Node names.
    node_order : list of str
        List with node names defining the order in which the nodes are
        arranged. Must have the same elements as node_names but the order can
        be different. The nodes are arranged in order of increasing angle
        starting at "start_pos" degrees.
    start_pos : float
        Angle in degrees that defines where the first node is plotted.
    start_between : bool
        If True, the layout starts with the position between the nodes, i.e.
        half a node separation is added to start_pos. Without group
        separators this is the same as adding "180. / len(node_names)" to
        start_pos.
    group_boundaries : None | array-like
        List of boundaries between groups at which point a "group_sep" will be
        inserted. E.g. "[0, len(node_names) / 2]" will create two groups.
        Values must be strictly increasing and lie in
        [0, len(node_names) - 1]. An empty list is treated like None.
    group_sep : float
        Group separation angle in degrees. See "group_boundaries".

    Returns
    -------
    node_angles : array, shape=(n_node_names,)
        Node angles in degrees.
    """
    n_nodes = len(node_names)

    if len(node_order) != n_nodes:
        raise ValueError("node_order has to be the same length as node_names")

    boundaries = None
    if group_boundaries is not None:
        boundaries = np.array(group_boundaries, dtype=np.int64)
        if boundaries.size == 0:
            # an empty list means "no groups"; indexing boundaries[0] below
            # would otherwise raise IndexError
            boundaries = None

    if boundaries is not None:
        if np.any(boundaries >= n_nodes) or np.any(boundaries < 0):
            raise ValueError('"group_boundaries" has to be between 0 and n_nodes - 1.')
        if len(boundaries) > 1 and np.any(np.diff(boundaries) <= 0):
            raise ValueError('"group_boundaries" must have non-decreasing values.')
        n_group_sep = len(boundaries)
    else:
        n_group_sep = 0

    # Convert node_order to a list of indices: for each name in node_names,
    # its first position in node_order. A dict lookup keeps this O(n) and
    # works for any sequence (not only objects with an ``index`` method).
    first_pos = {}
    for pos, name in enumerate(node_order):
        first_pos.setdefault(name, pos)
    try:
        node_order = np.array(
            [first_pos[name] for name in node_names], dtype=np.int64
        )
    except KeyError as err:
        raise ValueError(f"{err.args[0]!r} is not in node_order") from None

    if len(np.unique(node_order)) != n_nodes:
        raise ValueError("node_order has repeated entries")

    if n_nodes == 0:
        # nothing to lay out (and node_sep below would divide by zero)
        return np.zeros(0, dtype=np.float64)

    node_sep = (360.0 - n_group_sep * group_sep) / n_nodes

    if start_between:
        start_pos += node_sep / 2

        if boundaries is not None and boundaries[0] == 0:
            # special case when a group separator is at the start
            start_pos += group_sep / 2
            boundaries = boundaries[1:] if n_group_sep > 1 else None

    # angular increments; the cumulative sum gives absolute positions
    node_angles = np.ones(n_nodes, dtype=np.float64) * node_sep
    node_angles[0] = start_pos
    if boundaries is not None:
        node_angles[boundaries] += group_sep

    node_angles = np.cumsum(node_angles)[node_order]

    return node_angles


def _plot_connectivity_circle_onpick(
    event,
    fig=None,
    ax=None,
    indices=None,
    n_nodes=0,
    node_angles=None,
    ylim=(9, 10),
):
    """Isolate connections around a single node when user left clicks a node.

    On right click, resets all connections.

    Parameters
    ----------
    event : matplotlib MouseEvent
        The button press event.
    fig, ax : Figure, PolarAxes
        The figure to redraw and the axes whose clicks are handled.
    indices : sequence of two arrays
        Node index pairs, in the same order as the connection patches were
        added to ``ax``.
    n_nodes : int
        Number of nodes (unused, kept for backward compatibility).
    node_angles : array-like
        Node angles in radians.
    ylim : tuple of float
        Radial band in which a left click is interpreted as a node click.
    """
    if event.inaxes != ax:
        return

    if event.button == 1:  # left click
        # click must be near node radius
        if not ylim[0] <= event.ydata <= ylim[1]:
            return

        # Use the *circular* angular distance, so that a click just below
        # 2*pi still selects a node located just above 0 rad.
        two_pi = 2 * np.pi
        diff = (event.xdata - np.asarray(node_angles)) % two_pi
        node = np.argmin(np.minimum(diff, two_pi - diff))

        # the first len(indices[0]) patches are the connection lines
        patches = event.inaxes.patches
        for ii, (x, y) in enumerate(zip(indices[0], indices[1])):
            patches[ii].set_visible(node in [x, y])
        fig.canvas.draw()
    elif event.button == 3:  # right click
        patches = event.inaxes.patches
        for ii in range(np.size(indices, axis=1)):
            patches[ii].set_visible(True)
        fig.canvas.draw()


def _plot_connectivity_circle(
    con,
    node_names,
    indices=None,
    n_lines=None,
    node_angles=None,
    node_width=None,
    node_height=None,
    node_colors=None,
    facecolor="black",
    textcolor="white",
    node_edgecolor="black",
    linewidth=1.5,
    colormap="hot",
    vmin=None,
    vmax=None,
    colorbar=True,
    title=None,
    colorbar_size=None,
    colorbar_pos=None,
    fontsize_title=12,
    fontsize_names=8,
    fontsize_colorbar=8,
    padding=6.0,
    ax=None,
    interactive=True,
    node_linewidth=2.0,
    show=True,
):
    """Draw connections as curved lines between nodes placed on a circle.

    Parameters
    ----------
    con : array-like
        Either a square (n_nodes, n_nodes) matrix, of which the strictly
        lower-triangular part is used, or a 1D array of connection values
        paired with ``indices``. NaN entries are never drawn.
    node_names : list of str
        One label per node.
    indices : sequence of two array-like | None
        ``(rows, cols)`` node indices of each entry of a 1D ``con``. Required
        when ``con`` is 1D; replaced by the lower-triangular indices when
        ``con`` is 2D.
    n_lines : int | None
        If not None, only connections whose absolute value is at least the
        ``n_lines``-th largest absolute value are drawn.
    node_angles : array-like | None
        Node positions in degrees. If None, nodes are spread uniformly.
    node_width : float | None
        Node width in degrees. If None, the smallest circular angular gap
        between any two nodes is used.
    node_height : float | None
        Radial height of the node ring (default 1).
    node_colors : list | None
        Node colors, cycled if shorter than the number of nodes. If None,
        colors are sampled from the "spectral" colormap.
    facecolor, textcolor, node_edgecolor : color
        Background, text and node edge colors.
    linewidth, node_linewidth : float
        Line widths of the connections and of the node edges.
    colormap : str | Colormap
        Colormap for the connections (resolved via ``_get_cmap``).
    vmin, vmax : float | None
        Color limits. Default to the min/max of the drawn connections.
    colorbar : bool
        Whether to add a colorbar.
    title : str | None
        Axes title.
    colorbar_size, colorbar_pos : float | tuple | None
        Forwarded to ``fig.colorbar`` as ``shrink`` and ``anchor``.
    fontsize_title, fontsize_names, fontsize_colorbar : float
        Font sizes.
    padding : float
        Extra radial space beyond the node ring for the labels.
    ax : PolarAxes | None
        Axes to draw into; a new figure is created if None.
    interactive : bool
        If True, left-click on a node isolates its connections and
        right-click restores all of them.
    show : bool
        Passed to ``plt_show``.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure handle.
    ax : matplotlib.projections.polar.PolarAxes
        The axes handle.
    """
    import matplotlib.patches as m_patches
    import matplotlib.path as m_path
    import matplotlib.pyplot as plt
    from matplotlib.projections.polar import PolarAxes

    _validate_type(ax, (None, PolarAxes))

    n_nodes = len(node_names)
    two_pi = 2 * np.pi

    if node_angles is not None:
        # accept lists/tuples as well as arrays
        node_angles = np.asarray(node_angles, dtype=np.float64)
        if len(node_angles) != n_nodes:
            raise ValueError("node_angles has to be the same length as node_names")
        # convert it to radians
        node_angles = node_angles * np.pi / 180
    else:
        # uniform layout on unit circle
        node_angles = np.linspace(0, two_pi, n_nodes, endpoint=False)

    if node_width is None:
        # widths correspond to the minimum *circular* angle between two
        # nodes, so nodes on either side of 0 rad do not overlap
        diff = np.abs(node_angles[None, :] - node_angles[:, None]) % two_pi
        dist_mat = np.minimum(diff, two_pi - diff)
        dist_mat[np.diag_indices(n_nodes)] = 1e9
        node_width = np.min(dist_mat)
    else:
        node_width = node_width * np.pi / 180

    if node_height is None:
        node_height = 1.0

    if node_colors is not None:
        if len(node_colors) < n_nodes:
            node_colors = cycle(node_colors)
    else:
        # assign colors using colormap
        try:
            spectral = plt.cm.spectral
        except AttributeError:
            spectral = plt.cm.Spectral
        node_colors = [spectral(i / float(n_nodes)) for i in range(n_nodes)]

    # handle 1D and 2D connectivity information
    con = np.asanyarray(con)
    if con.ndim == 1:
        if indices is None:
            raise ValueError("indices has to be provided if con.ndim == 1")
    elif con.ndim == 2:
        if con.shape[0] != n_nodes or con.shape[1] != n_nodes:
            raise ValueError("con has to be 1D or a square matrix")
        # we use the lower-triangular part
        indices = np.tril_indices(n_nodes, -1)
        con = con[indices]
    else:
        raise ValueError("con has to be 1D or a square matrix")

    # get the colormap
    colormap = _get_cmap(colormap)

    # Use a polar axes
    if ax is None:
        fig = plt.figure(figsize=(8, 8), facecolor=facecolor, layout="constrained")
        ax = fig.add_subplot(polar=True)
    else:
        fig = ax.figure
    ax.set_facecolor(facecolor)

    # No ticks, we'll put our own
    ax.set_xticks([])
    ax.set_yticks([])

    # Set y axes limit, add additional space if requested
    ax.set_ylim(0, 10 + padding)

    # Remove the black axes border which may obscure the labels
    ax.spines["polar"].set_visible(False)

    # Draw lines between connected nodes, only draw the strongest connections.
    # NaN entries are left out when picking the threshold: they sort last and
    # would otherwise become the threshold, hiding every connection.
    con_abs = np.abs(con)
    valid_abs = con_abs[~np.isnan(con_abs)]
    if n_lines is not None and len(valid_abs) > n_lines:
        con_thresh = np.sort(valid_abs)[-n_lines]
    else:
        con_thresh = 0.0

    # get the connections which we are drawing (NaN compares False here)
    con_draw_idx = np.where(con_abs >= con_thresh)[0]
    con = con[con_draw_idx]
    con_abs = con_abs[con_draw_idx]
    indices = [np.asarray(ind)[con_draw_idx] for ind in indices]

    # sort by ascending strength: the strongest connections are drawn last,
    # i.e. on top of the weaker ones
    sort_idx = np.argsort(con_abs)
    del con_abs
    con = con[sort_idx]
    indices = [ind[sort_idx] for ind in indices]

    # Get vmin vmax for color scaling (con only holds drawn connections now)
    if vmin is None:
        vmin = np.min(con)
    if vmax is None:
        vmax = np.max(con)
    vrange = vmax - vmin

    # We want to add some "noise" to the start and end position of the
    # edges: We modulate the noise with the number of connections of the
    # node and the connection strength, such that the strongest connections
    # are closer to the node center
    nodes_n_con = np.zeros((n_nodes), dtype=np.int64)
    for i, j in zip(indices[0], indices[1]):
        nodes_n_con[i] += 1
        nodes_n_con[j] += 1

    # initialize random number generator so plot is reproducible
    rng = np.random.mtrand.RandomState(0)

    n_con = len(indices[0])
    noise_max = 0.25 * node_width
    start_noise = rng.uniform(-noise_max, noise_max, n_con)
    end_noise = rng.uniform(-noise_max, noise_max, n_con)

    nodes_n_con_seen = np.zeros_like(nodes_n_con)
    for i, (start, end) in enumerate(zip(indices[0], indices[1])):
        nodes_n_con_seen[start] += 1
        nodes_n_con_seen[end] += 1

        start_noise[i] *= (nodes_n_con[start] - nodes_n_con_seen[start]) / float(
            nodes_n_con[start]
        )
        end_noise[i] *= (nodes_n_con[end] - nodes_n_con_seen[end]) / float(
            nodes_n_con[end]
        )

    # scale connectivity for colormap (vmin<=>0, vmax<=>1). When vmin == vmax
    # map everything to 0, as matplotlib.colors.Normalize does, instead of
    # dividing by zero (NaN would be rendered with the colormap's "bad" color).
    if vrange == 0:
        con_val_scaled = np.zeros(len(con))
    else:
        con_val_scaled = (con - vmin) / vrange

    # Finally, we draw the connections
    for pos, (i, j) in enumerate(zip(indices[0], indices[1])):
        # Start point
        t0, r0 = node_angles[i], 10

        # End point
        t1, r1 = node_angles[j], 10

        # Some noise in start and end point
        t0 += start_noise[pos]
        t1 += end_noise[pos]

        # cubic Bezier curve pulled towards radius 5
        verts = [(t0, r0), (t0, 5), (t1, 5), (t1, r1)]
        codes = [
            m_path.Path.MOVETO,
            m_path.Path.CURVE4,
            m_path.Path.CURVE4,
            m_path.Path.LINETO,
        ]
        path = m_path.Path(verts, codes)

        color = colormap(con_val_scaled[pos])

        # Actual line
        patch = m_patches.PathPatch(
            path, fill=False, edgecolor=color, linewidth=linewidth, alpha=1.0
        )
        ax.add_patch(patch)

    # Draw ring with colored nodes
    height = np.ones(n_nodes) * node_height
    bars = ax.bar(
        node_angles,
        height,
        width=node_width,
        bottom=9,
        edgecolor=node_edgecolor,
        lw=node_linewidth,
        facecolor=".9",
        align="center",
    )

    for bar, color in zip(bars, node_colors):
        bar.set_facecolor(color)

    # Draw node labels
    angles_deg = 180 * node_angles / np.pi
    for name, angle_rad, angle_deg in zip(node_names, node_angles, angles_deg):
        # Decide on the angle modulo 360 so layouts outside [90, 450) (e.g.
        # the default [0, 360)) also get upright labels on the right side.
        if not 90 <= angle_deg % 360 < 270:
            ha = "left"
        else:
            # Flip the label, so text is always upright
            angle_deg += 180
            ha = "right"

        ax.text(
            angle_rad,
            9.4 + node_height,
            name,
            size=fontsize_names,
            rotation=angle_deg,
            rotation_mode="anchor",
            horizontalalignment=ha,
            verticalalignment="center",
            color=textcolor,
        )

    if title is not None:
        ax.set_title(title, color=textcolor, fontsize=fontsize_title)

    if colorbar:
        sm = plt.cm.ScalarMappable(cmap=colormap, norm=plt.Normalize(vmin, vmax))
        sm.set_array(np.linspace(vmin, vmax))
        colorbar_kwargs = dict()
        if colorbar_size is not None:
            colorbar_kwargs.update(shrink=colorbar_size)
        if colorbar_pos is not None:
            colorbar_kwargs.update(anchor=colorbar_pos)
        cb = fig.colorbar(sm, ax=ax, **colorbar_kwargs)
        cb_yticks = plt.getp(cb.ax.axes, "yticklabels")
        cb.ax.tick_params(labelsize=fontsize_colorbar)
        plt.setp(cb_yticks, color=textcolor)

    # Add callback for interaction
    if interactive:
        callback = partial(
            _plot_connectivity_circle_onpick,
            fig=fig,
            ax=ax,
            indices=indices,
            n_nodes=n_nodes,
            node_angles=node_angles,
            # the node ring spans radii [9, 9 + node_height]
            ylim=(9, 9 + node_height),
        )

        fig.canvas.mpl_connect("button_press_event", callback)

    plt_show(show)
    return fig, ax


def plot_channel_labels_circle(labels, colors=None, picks=None, **kwargs):
    """Plot labels for each channel in a circle plot.

    .. note:: This primarily makes sense for sEEG channels where each
              channel can be assigned an anatomical label as the electrode
              passes through various brain areas.

    Parameters
    ----------
    labels : dict
        Lists of labels (values) associated with each channel (keys).
    colors : dict
        The color (value) for each label (key). If None, the connection
        colormap is left at the plotting function's default.
    picks : list | tuple
        The channels to consider.
    **kwargs : kwargs
        Keyword arguments for
        :func:`mne_connectivity.viz.plot_connectivity_circle`.

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        The figure handle.
    axes : instance of matplotlib.projections.polar.PolarAxes
        The subplot handle.
    """
    from matplotlib.colors import LinearSegmentedColormap

    _validate_type(labels, dict, "labels")
    _validate_type(colors, (dict, None), "colors")
    _validate_type(picks, (list, tuple, None), "picks")

    if picks is not None:
        labels = {k: v for k, v in labels.items() if k in picks}

    ch_names = list(labels.keys())
    # dict.fromkeys de-duplicates while keeping first-seen order; iterating a
    # set of strings gives a different order on each interpreter run (hash
    # randomization), which made the layout non-reproducible.
    all_labels = list(dict.fromkeys(label for val in labels.values() for label in val))
    n_labels = len(all_labels)
    label_cmap = None
    if colors is not None:
        for label in all_labels:
            if label not in colors:
                raise ValueError(f"No color provided for {label} in `colors`")
        # update all_labels, there may be unconnected labels in colors
        all_labels = list(colors.keys())
        n_labels = len(all_labels)
        # make colormap
        label_colors = [colors[label] for label in all_labels]
        node_colors = ["black"] * len(ch_names) + label_colors
        label_cmap = LinearSegmentedColormap.from_list(
            "label_cmap", label_colors, N=len(label_colors)
        )
    else:
        node_colors = None

    # channel nodes come first, followed by one node per label
    node_names = ch_names + all_labels
    label_pos = {label: ii for ii, label in enumerate(all_labels)}
    # NaN marks "no connection"; such entries are never drawn
    con = np.full((len(node_names), len(node_names)), np.nan)
    for idx, ch_name in enumerate(ch_names):
        for label in labels[ch_name]:
            node_idx = len(ch_names) + label_pos[label]
            # value i / n_labels falls into colormap bin i (vmin=0, vmax=1)
            label_color = label_pos[label] / n_labels
            con[idx, node_idx] = con[node_idx, idx] = label_color  # symmetric

    # plot
    node_order = ch_names + all_labels[::-1]
    node_angles = circular_layout(
        node_names, node_order, start_pos=90, group_boundaries=[0, len(ch_names)]
    )

    # provide defaults but don't overwrite
    kwargs.setdefault("node_angles", node_angles)
    kwargs.setdefault("colorbar", False)
    kwargs.setdefault("node_colors", node_colors)
    kwargs.setdefault("vmin", 0)
    kwargs.setdefault("vmax", 1)
    if label_cmap is not None:
        # without ``colors`` there is no label colormap (previously this
        # path raised NameError); keep the plotting function's default
        kwargs.setdefault("colormap", label_cmap)

    return _plot_connectivity_circle(con, node_names, **kwargs)