# Authors: The MNE-Python contributors. # License: BSD-3-Clause # Copyright the MNE-Python contributors. import copy import os import os.path as op import time import traceback import warnings from functools import partial from io import BytesIO import numpy as np from scipy.interpolate import interp1d from scipy.sparse import csr_array from scipy.spatial.distance import cdist from ..._fiff.meas_info import Info from ..._fiff.pick import pick_types from ..._freesurfer import ( _estimate_talxfm_rigid, _get_aseg, _get_head_surface, _get_skull_surface, read_freesurfer_lut, read_talxfm, vertex_to_mni, ) from ...defaults import DEFAULTS, _handle_default from ...surface import _marching_cubes, _mesh_borders, mesh_edges from ...transforms import ( Transform, _frame_to_str, _get_trans, _get_transforms_to_coord_frame, apply_trans, ) from ...utils import ( Bunch, _auto_weakref, _check_fname, _check_option, _ensure_int, _path_like, _ReuseCycle, _to_rgb, _validate_type, fill_doc, get_subjects_dir, logger, use_log_level, verbose, warn, ) from .._3d import ( _check_views, _handle_sensor_types, _handle_time, _plot_forward, _plot_helmet, _plot_sensors_3d, _process_clim, ) from .._3d_overlay import _LayeredMesh from ..ui_events import ( ColormapRange, PlaybackSpeed, TimeChange, VertexSelect, _get_event_channel, disable_ui_events, publish, subscribe, unsubscribe, ) from ..utils import ( _generate_default_filename, _get_color_list, _save_ndarray_img, _show_help_fig, concatenate_images, safe_event, ) from .colormap import calculate_lut from .surface import _Surface from .view import _lh_views_dict, views_dicts @fill_doc class Brain: """Class for visualizing a brain. .. warning:: The API for this class is not currently complete. We suggest using :meth:`mne.viz.plot_source_estimates` with the PyVista backend enabled to obtain a ``Brain`` instance. Parameters ---------- subject : str Subject name in Freesurfer subjects dir. .. versionchanged:: 1.2 This parameter was renamed from ``subject_id`` to ``subject``. hemi : str Hemisphere id (ie 'lh', 'rh', 'both', or 'split'). In the case of 'both', both hemispheres are shown in the same window. In the case of 'split' hemispheres are displayed side-by-side in different viewing panes. surf : str FreeSurfer surface mesh name (ie 'white', 'inflated', etc.). title : str Title for the window. cortex : str, list, dict Specifies how the cortical surface is rendered. Options: 1. The name of one of the preset cortex styles: ``'classic'`` (default), ``'high_contrast'``, ``'low_contrast'``, or ``'bone'``. 2. A single color-like argument to render the cortex as a single color, e.g. ``'red'`` or ``(0.1, 0.4, 1.)``. 3. A list of two color-like used to render binarized curvature values for gyral (first) and sulcal (second). regions, e.g., ``['red', 'blue']`` or ``[(1, 0, 0), (0, 0, 1)]``. 4. A dict containing keys ``'vmin', 'vmax', 'colormap'`` with values used to render the binarized curvature (where 0 is gyral, 1 is sulcal). .. versionchanged:: 0.24 Add support for non-string arguments. alpha : float in [0, 1] Alpha level to control opacity of the cortical surface. size : int | array-like, shape (2,) The size of the window, in pixels. can be one number to specify a square window, or a length-2 sequence to specify (width, height). background : tuple(int, int, int) The color definition of the background: (red, green, blue). foreground : matplotlib color Color of the foreground (will be used for colorbars and text). None (default) will use black or white depending on the value of ``background``. figure : list of Figure | None If None (default), a new window will be created with the appropriate views. subjects_dir : str | None If not None, this directory will be used as the subjects directory instead of the value set using the SUBJECTS_DIR environment variable. %(views)s offset : bool | str If True, shifts the right- or left-most x coordinate of the left and right surfaces, respectively, to be at zero. This is useful for viewing inflated surface where hemispheres typically overlap. Can be "auto" (default) use True with inflated surfaces and False otherwise (Default: 'auto'). Only used when ``hemi='both'``. .. versionchanged:: 0.23 Default changed to "auto". offscreen : bool Deprecated and will be removed in 1.9, do not use. interaction : str Can be "trackball" (default) or "terrain", i.e. a turntable-style camera. units : str Can be 'm' or 'mm' (default). %(view_layout)s silhouette : dict | bool As a dict, it contains the ``color``, ``linewidth``, ``alpha`` opacity and ``decimate`` (level of decimation between 0 and 1 or None) of the brain's silhouette to display. If True, the default values are used and if False, no silhouette will be displayed. Defaults to False. %(theme_3d)s show : bool Display the window as soon as it is ready. Defaults to True. block : bool Deprecated and will be removed in 1.9, do not use. Consider using :func:`matplotlib.pyplot.show` with ``block=True`` instead. Attributes ---------- geo : dict A dictionary of PyVista surface objects for each hemisphere. overlays : dict The overlays. Notes ----- The figure will publish and subscribe to the following UI events: * :class:`~mne.viz.ui_events.TimeChange` * :class:`~mne.viz.ui_events.PlaybackSpeed` * :class:`~mne.viz.ui_events.ColormapRange`, ``kind="distributed_source_power"`` * :class:`~mne.viz.ui_events.VertexSelect` This table shows the capabilities of each Brain backend ("✓" for full support, and "-" for partial support): .. table:: :widths: auto +-------------------------------------+--------------+---------------+ | 3D function: | surfer.Brain | mne.viz.Brain | +=====================================+==============+===============+ | :meth:`add_annotation` | ✓ | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`add_data` | ✓ | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`add_dipole` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`add_foci` | ✓ | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`add_forward` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`add_head` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`add_label` | ✓ | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`add_sensors` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`add_skull` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`add_text` | ✓ | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`add_volume_labels` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`close` | ✓ | ✓ | +-------------------------------------+--------------+---------------+ | data | ✓ | ✓ | +-------------------------------------+--------------+---------------+ | foci | ✓ | | +-------------------------------------+--------------+---------------+ | labels | ✓ | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`remove_data` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`remove_dipole` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`remove_forward` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`remove_head` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`remove_labels` | ✓ | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`remove_annotations` | - | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`remove_sensors` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`remove_skull` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`remove_text` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`remove_volume_labels` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`save_image` | ✓ | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`save_movie` | ✓ | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`screenshot` | ✓ | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`show_view` | ✓ | ✓ | +-------------------------------------+--------------+---------------+ | TimeViewer | ✓ | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`get_picked_points` | | ✓ | +-------------------------------------+--------------+---------------+ | :meth:`add_data(volume) ` | | ✓ | +-------------------------------------+--------------+---------------+ | view_layout | | ✓ | +-------------------------------------+--------------+---------------+ | flatmaps | | ✓ | +-------------------------------------+--------------+---------------+ | vertex picking | | ✓ | +-------------------------------------+--------------+---------------+ | label picking | | ✓ | +-------------------------------------+--------------+---------------+ """ def __init__( self, subject, hemi="both", surf="pial", title=None, cortex="classic", alpha=1.0, size=800, background="black", foreground=None, figure=None, subjects_dir=None, views="auto", *, offset="auto", offscreen=None, interaction="trackball", units="mm", view_layout="vertical", silhouette=False, theme=None, show=True, block=None, ): from ..backends.renderer import _get_renderer, backend _validate_type(subject, str, "subject") self._surf = surf if offscreen is not None: warn( "The 'offscreen' parameter is deprecated and will be removed in 1.9. " "as it has no effect", FutureWarning, ) if hemi is None: hemi = "vol" hemi = self._check_hemi(hemi, extras=("both", "split", "vol")) if hemi in ("both", "split"): self._hemis = ("lh", "rh") else: assert hemi in ("lh", "rh", "vol") self._hemis = (hemi,) self._view_layout = _check_option( "view_layout", view_layout, ("vertical", "horizontal") ) if figure is not None and not isinstance(figure, int): backend._check_3d_figure(figure) if title is None: self._title = subject else: self._title = title self._interaction = "trackball" self._bg_color = _to_rgb(background, name="background") if foreground is None: foreground = "w" if sum(self._bg_color) < 2 else "k" self._fg_color = _to_rgb(foreground, name="foreground") del background, foreground views = _check_views(surf, views, hemi) col_dict = dict(lh=1, rh=1, both=1, split=2, vol=1) shape = (len(views), col_dict[hemi]) if self._view_layout == "horizontal": shape = shape[::-1] self._subplot_shape = shape size = tuple(np.atleast_1d(size).round(0).astype(int).flat) if len(size) not in (1, 2): raise ValueError( '"size" parameter must be an int or length-2 sequence of ints.' ) size = size if len(size) == 2 else size * 2 # 1-tuple to 2-tuple subjects_dir = get_subjects_dir(subjects_dir) if subjects_dir is not None: subjects_dir = str(subjects_dir) if block is not None: warn( "block is deprecated and will be removed in 1.9, use " "plt.show(block=True) instead" ) self.time_viewer = False self._hash = time.time_ns() self._block = block self._hemi = hemi self._units = units self._alpha = float(alpha) self._subject = subject self._subjects_dir = subjects_dir self._views = views self._times = None self._vertex_to_label_id = dict() self._annotation_labels = dict() self._labels = {"lh": list(), "rh": list()} self._unnamed_label_id = 0 # can only grow self._annots = {"lh": list(), "rh": list()} self._layered_meshes = dict() self._actors = dict() self._cleaned = False # default values for silhouette self._silhouette = { "color": self._bg_color, "line_width": 2, "alpha": alpha, "decimate": 0.9, } _validate_type(silhouette, (dict, bool), "silhouette") if isinstance(silhouette, dict): self._silhouette.update(silhouette) self.silhouette = True else: self.silhouette = silhouette self._scalar_bar = None # for now only one time label can be added # since it is the same for all figures self._time_label_added = False # array of data used by TimeViewer self._data = {} self.geo = {} self.set_time_interpolation("nearest") geo_kwargs = self._cortex_colormap(cortex) # evaluate at the midpoint of the used colormap val = -geo_kwargs["vmin"] / (geo_kwargs["vmax"] - geo_kwargs["vmin"]) self._brain_color = geo_kwargs["colormap"](val) # load geometry for one or both hemispheres as necessary _validate_type(offset, (str, bool), "offset") if isinstance(offset, str): _check_option("offset", offset, ("auto",), extra="when str") offset = surf in ("inflated", "flat") offset = None if (not offset or hemi != "both") else 0.0 logger.debug(f"Hemi offset: {offset}") _validate_type(theme, (str, None), "theme") self._renderer = _get_renderer( name=self._title, size=size, bgcolor=self._bg_color, shape=shape, fig=figure ) self._renderer._window_close_connect(self._clean) self._renderer._window_set_theme(theme) self.plotter = self._renderer.plotter self.widgets = dict() self._setup_canonical_rotation() # plot hemis for h in ("lh", "rh"): if h not in self._hemis: continue # don't make surface if not chosen # Initialize a Surface object as the geometry geo = _Surface( self._subject, h, surf, self._subjects_dir, offset, units=self._units, x_dir=self._rigid[0, :3], ) # Load in the geometry and curvature geo.load_geometry() geo.load_curvature() self.geo[h] = geo for _, _, v in self._iter_views(h): if self._layered_meshes.get(h) is None: mesh = _LayeredMesh( renderer=self._renderer, vertices=self.geo[h].coords, triangles=self.geo[h].faces, normals=self.geo[h].nn, ) mesh.map() # send to GPU if self.geo[h].bin_curv is None: scalars = mesh._default_scalars[:, 0] else: scalars = self.geo[h].bin_curv mesh.add_overlay( scalars=scalars, colormap=geo_kwargs["colormap"], rng=[geo_kwargs["vmin"], geo_kwargs["vmax"]], opacity=alpha, name="curv", ) self._layered_meshes[h] = mesh # add metadata to the mesh for picking mesh._polydata._hemi = h else: actor = self._layered_meshes[h]._actor self._renderer.plotter.add_actor(actor, render=False) if self.silhouette: mesh = self._layered_meshes[h] self._renderer._silhouette( mesh=mesh._polydata, color=self._silhouette["color"], line_width=self._silhouette["line_width"], alpha=self._silhouette["alpha"], decimate=self._silhouette["decimate"], ) self._set_camera(**views_dicts[h][v]) self.interaction = interaction self._closed = False if show: self.show() # update the views once the geometry is all set for h in self._hemis: for ri, ci, v in self._iter_views(h): self.show_view(v, row=ri, col=ci, hemi=h, update=False) if surf == "flat": self._renderer.set_interaction("rubber_band_2d") self._renderer._update() def _setup_canonical_rotation(self): self._rigid = np.eye(4) try: xfm = _estimate_talxfm_rigid(self._subject, self._subjects_dir) except Exception: logger.info( "Could not estimate rigid Talairach alignment, using identity matrix" ) else: self._rigid[:] = xfm def setup_time_viewer(self, time_viewer=True, show_traces=True): """Configure the time viewer parameters. Parameters ---------- time_viewer : bool If True, enable widgets interaction. Defaults to True. show_traces : bool If True, enable visualization of time traces. Defaults to True. Notes ----- The keyboard shortcuts are the following: '?': Display help window 'i': Toggle interface 's': Apply auto-scaling 'r': Restore original clim 'c': Clear all traces 'n': Shift the time forward by the playback speed 'b': Shift the time backward by the playback speed 'Space': Start/Pause playback 'Up': Decrease camera elevation angle 'Down': Increase camera elevation angle 'Left': Decrease camera azimuth angle 'Right': Increase camera azimuth angle """ from ..backends._utils import _qt_app_exec if self.time_viewer: return if not self._data: raise ValueError("No data to visualize. See ``add_data``.") self.time_viewer = time_viewer self.orientation = list(_lh_views_dict.keys()) self.default_smoothing_range = [-1, 15] # Default configuration self.visibility = False self.default_playback_speed_range = [0.01, 1] self.default_playback_speed_value = 0.01 self.default_status_bar_msg = "Press ? for help" self.default_label_extract_modes = { "stc": ["mean", "max"], "src": ["mean_flip", "pca_flip", "auto"], } self.annot = None self.label_extract_mode = None all_keys = ("lh", "rh", "vol") self.act_data_smooth = {key: (None, None) for key in all_keys} self.color_list = _get_color_list() # remove grey for better contrast on the brain self.color_list.remove("#7f7f7f") self.color_cycle = _ReuseCycle(self.color_list) self.mpl_canvas = None self.help_canvas = None self.rms = None self.picked_patches = {key: list() for key in all_keys} self.picked_points = {key: list() for key in all_keys} self.pick_table = dict() self._spheres = list() self._mouse_no_mvt = -1 # Derived parameters: self.playback_speed = self.default_playback_speed_value _validate_type(show_traces, (bool, str, "numeric"), "show_traces") self.interactor_fraction = 0.25 if isinstance(show_traces, str): self.show_traces = True self.separate_canvas = False self.traces_mode = "vertex" if show_traces == "separate": self.separate_canvas = True elif show_traces == "label": self.traces_mode = "label" else: assert show_traces == "vertex" # guaranteed above else: if isinstance(show_traces, bool): self.show_traces = show_traces else: show_traces = float(show_traces) if not 0 < show_traces < 1: raise ValueError( "show traces, if numeric, must be between 0 and 1, " f"got {show_traces}" ) self.show_traces = True self.interactor_fraction = show_traces self.traces_mode = "vertex" self.separate_canvas = False del show_traces self._configure_time_label() self._configure_scalar_bar() self._configure_shortcuts() self._configure_picking() self._configure_dock() self._configure_tool_bar() self._configure_menu() self._configure_status_bar() self._configure_help() # show everything at the end self.toggle_interface() self._renderer.show() # sizes could change, update views for hemi in ("lh", "rh"): for ri, ci, v in self._iter_views(hemi): self.show_view(view=v, row=ri, col=ci) self._renderer._process_events() self._renderer._update() # finally, show the MplCanvas if self.show_traces: self.mpl_canvas.show() if self._block: _qt_app_exec(self._renderer.figure.store["app"]) @safe_event def _clean(self): # resolve the reference cycle self._renderer._window_close_disconnect() self.clear_glyphs() self.remove_annotations() # clear init actors for hemi in self._layered_meshes: self._layered_meshes[hemi]._clean() self._clear_callbacks() self._clear_widgets() if getattr(self, "mpl_canvas", None) is not None: self.mpl_canvas.clear() if getattr(self, "act_data_smooth", None) is not None: for key in list(self.act_data_smooth.keys()): self.act_data_smooth[key] = None # XXX this should be done in PyVista for renderer in self._renderer._all_renderers: renderer.RemoveAllLights() # app_window cannot be set to None because it is used in __del__ for key in ("lighting", "interactor", "_RenderWindow"): setattr(self.plotter, key, None) # Qt LeaveEvent requires _Iren so we use _FakeIren instead of None # to resolve the ref to vtkGenericRenderWindowInteractor self.plotter._Iren = _FakeIren() if getattr(self.plotter, "picker", None) is not None: self.plotter.picker = None # XXX end PyVista for key in ( "plotter", "window", "dock", "tool_bar", "menu_bar", "interactor", "mpl_canvas", "time_actor", "picked_renderer", "act_data_smooth", "_scalar_bar", "actions", "widgets", "geo", "_data", ): setattr(self, key, None) self._cleaned = True def toggle_interface(self, value=None): """Toggle the interface. Parameters ---------- value : bool | None If True, the widgets are shown and if False, they are hidden. If None, the state of the widgets is toggled. Defaults to None. """ if value is None: self.visibility = not self.visibility else: self.visibility = value # update tool bar and dock with self._renderer._window_ensure_minimum_sizes(): if self.visibility: self._renderer._dock_show() self._renderer._tool_bar_update_button_icon( name="visibility", icon_name="visibility_on" ) else: self._renderer._dock_hide() self._renderer._tool_bar_update_button_icon( name="visibility", icon_name="visibility_off" ) self._renderer._update() def apply_auto_scaling(self): """Detect automatically fitting scaling parameters.""" self._update_auto_scaling() def restore_user_scaling(self): """Restore original scaling parameters.""" self._update_auto_scaling(restore=True) def toggle_playback(self, value=None): """Toggle time playback. Parameters ---------- value : bool | None If True, automatic time playback is enabled and if False, it's disabled. If None, the state of time playback is toggled. Defaults to None. """ self._renderer._toggle_playback(value) def reset(self): """Reset view, current time and time step.""" self.reset_view() self._renderer._reset_time() def set_playback_speed(self, speed): """Set the time playback speed. Parameters ---------- speed : float The speed of the playback. """ publish(self, PlaybackSpeed(speed=speed)) def _configure_time_label(self): self.time_actor = self._data.get("time_actor") if self.time_actor is not None: self.time_actor.SetPosition(0.5, 0.03) self.time_actor.GetTextProperty().SetJustificationToCentered() self.time_actor.GetTextProperty().BoldOn() def _configure_scalar_bar(self): if self._scalar_bar is not None: self._scalar_bar.SetOrientationToVertical() self._scalar_bar.SetHeight(0.6) self._scalar_bar.SetWidth(0.05) self._scalar_bar.SetPosition(0.02, 0.2) def _configure_dock_playback_widget(self, name): len_time = len(self._data["time"]) - 1 # Time widget if len_time < 1: self.widgets["time"] = None self.widgets["min_time"] = None self.widgets["max_time"] = None self.widgets["current_time"] = None else: @_auto_weakref def current_time_func(): return self._current_time self._renderer._enable_time_interaction( self, current_time_func, self._data["time"], self.default_playback_speed_value, self.default_playback_speed_range, ) # Time label current_time = self._current_time assert current_time is not None # should never be the case, float time_label = self._data["time_label"] if callable(time_label): current_time = time_label(current_time) else: current_time = time_label if self.time_actor is not None: self.time_actor.SetInput(current_time) del current_time def _configure_dock_orientation_widget(self, name): layout = self._renderer._dock_add_group_box(name) # Renderer widget rends = [str(i) for i in range(len(self._renderer._all_renderers))] if len(rends) > 1: @_auto_weakref def select_renderer(idx): idx = int(idx) loc = self._renderer._index_to_loc(idx) self.plotter.subplot(*loc) self.widgets["renderer"] = self._renderer._dock_add_combo_box( name="Renderer", value="0", rng=rends, callback=select_renderer, layout=layout, ) # Use 'lh' as a reference for orientation for 'both' if self._hemi == "both": hemis_ref = ["lh"] else: hemis_ref = self._hemis orientation_data = [None] * len(rends) for hemi in hemis_ref: for ri, ci, v in self._iter_views(hemi): idx = self._renderer._loc_to_index((ri, ci)) if v == "flat": _data = None else: _data = dict(default=v, hemi=hemi, row=ri, col=ci) orientation_data[idx] = _data @_auto_weakref def set_orientation(value, orientation_data=orientation_data): if "renderer" in self.widgets: idx = int(self.widgets["renderer"].get_value()) else: idx = 0 if orientation_data[idx] is not None: self.show_view( value, row=orientation_data[idx]["row"], col=orientation_data[idx]["col"], hemi=orientation_data[idx]["hemi"], ) self.widgets["orientation"] = self._renderer._dock_add_combo_box( name=None, value=self.orientation[0], rng=self.orientation, callback=set_orientation, layout=layout, ) def _configure_dock_colormap_widget(self, name): fmax, fscale, fscale_power = _get_range(self) rng = [0, fmax * fscale] self._data["fscale"] = fscale layout = self._renderer._dock_add_group_box(name) text = "min / mid / max" if fscale_power != 0: text += f" (×1e{fscale_power:d})" self._renderer._dock_add_label( value=text, align=True, layout=layout, ) @_auto_weakref def update_single_lut_value(value, key): # Called by the sliders and spin boxes. self.update_lut(**{key: value / self._data["fscale"]}) keys = ("fmin", "fmid", "fmax") for key in keys: hlayout = self._renderer._dock_add_layout(vertical=False) self.widgets[key] = self._renderer._dock_add_slider( name=None, value=self._data[key] * self._data["fscale"], rng=rng, callback=partial(update_single_lut_value, key=key), double=True, layout=hlayout, ) self.widgets[f"entry_{key}"] = self._renderer._dock_add_spin_box( name=None, value=self._data[key] * self._data["fscale"], callback=partial(update_single_lut_value, key=key), rng=rng, layout=hlayout, ) self._renderer._layout_add_widget(layout, hlayout) # reset / minus / plus hlayout = self._renderer._dock_add_layout(vertical=False) self._renderer._dock_add_label( value="Rescale", align=True, layout=hlayout, ) self.widgets["reset"] = self._renderer._dock_add_button( name="↺", callback=self.restore_user_scaling, layout=hlayout, style="toolbutton", ) @_auto_weakref def fminus(): self._update_fscale(1.2**-0.25) self.widgets["fminus"] = self._renderer._dock_add_button( name="➖", callback=fminus, layout=hlayout, style="toolbutton", ) @_auto_weakref def fplus(): self._update_fscale(1.2**0.25) self.widgets["fplus"] = self._renderer._dock_add_button( name="➕", callback=fplus, layout=hlayout, style="toolbutton", ) self._renderer._layout_add_widget(layout, hlayout) def _configure_dock_trace_widget(self, name): if not self.show_traces: return # do not show trace mode for volumes if ( self._data.get("src", None) is not None and self._data["src"].kind == "volume" ): self._configure_vertex_time_course() return layout = self._renderer._dock_add_group_box(name) # setup candidate annots @_auto_weakref def _set_annot(annot): self.clear_glyphs() self.remove_labels() self.remove_annotations() self.annot = annot if annot == "None": self.traces_mode = "vertex" self._configure_vertex_time_course() else: self.traces_mode = "label" self._configure_label_time_course() self._renderer._update() # setup label extraction parameters @_auto_weakref def _set_label_mode(mode): if self.traces_mode != "label": return glyphs = copy.deepcopy(self.picked_patches) self.label_extract_mode = mode self.clear_glyphs() for hemi in self._hemis: for label_id in glyphs[hemi]: label = self._annotation_labels[hemi][label_id] vertex_id = label.vertices[0] self._add_label_glyph(hemi, None, vertex_id) self.mpl_canvas.axes.relim() self.mpl_canvas.axes.autoscale_view() self.mpl_canvas.update_plot() self._renderer._update() from ...label import _read_annot_cands from ...source_estimate import _get_allowed_label_modes dir_name = op.join(self._subjects_dir, self._subject, "label") cands = _read_annot_cands(dir_name, raise_error=False) cands = cands + ["None"] self.annot = cands[0] stc = self._data["stc"] modes = _get_allowed_label_modes(stc) if self._data["src"] is None: modes = [ m for m in modes if m not in self.default_label_extract_modes["src"] ] self.label_extract_mode = modes[-1] if self.traces_mode == "vertex": _set_annot("None") else: _set_annot(self.annot) self.widgets["annotation"] = self._renderer._dock_add_combo_box( name="Annotation", value=self.annot, rng=cands, callback=_set_annot, layout=layout, ) self.widgets["extract_mode"] = self._renderer._dock_add_combo_box( name="Extract mode", value=self.label_extract_mode, rng=modes, callback=_set_label_mode, layout=layout, ) def _configure_dock(self): self._renderer._dock_initialize() self._configure_dock_playback_widget(name="Playback") self._configure_dock_orientation_widget(name="Orientation") self._configure_dock_colormap_widget(name="Color Limits") self._configure_dock_trace_widget(name="Trace") # Smoothing widget self.widgets["smoothing"] = self._renderer._dock_add_spin_box( name="Smoothing", value=self._data["smoothing_steps"], rng=self.default_smoothing_range, callback=self.set_data_smoothing, double=False, ) self._renderer._dock_finalize() def _configure_mplcanvas(self): # Get the fractional components for the brain and mpl self.mpl_canvas = self._renderer._window_get_mplcanvas( brain=self, interactor_fraction=self.interactor_fraction, show_traces=self.show_traces, separate_canvas=self.separate_canvas, ) xlim = [np.min(self._data["time"]), np.max(self._data["time"])] with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) self.mpl_canvas.axes.set(xlim=xlim) if not self.separate_canvas: self._renderer._window_adjust_mplcanvas_layout() self.mpl_canvas.set_color( bg_color=self._bg_color, fg_color=self._fg_color, ) def _configure_vertex_time_course(self): if not self.show_traces: return if self.mpl_canvas is None: self._configure_mplcanvas() else: self.clear_glyphs() # plot RMS of the activation y = np.concatenate( list(v[0] for v in self.act_data_smooth.values() if v[0] is not None) ) rms = np.linalg.norm(y, axis=0) / np.sqrt(len(y)) del y (self.rms,) = self.mpl_canvas.axes.plot( self._data["time"], rms, lw=3, label="RMS", zorder=3, color=self._fg_color, alpha=0.5, ls=":", ) # now plot the time line self.plot_time_line(update=False) # then the picked points for idx, hemi in enumerate(["lh", "rh", "vol"]): act_data = self.act_data_smooth.get(hemi, [None])[0] if act_data is None: continue hemi_data = self._data[hemi] vertices = hemi_data["vertices"] # simulate a picked renderer if self._hemi in ("both", "rh") or hemi == "vol": idx = 0 self.picked_renderer = self._renderer._all_renderers[idx] # initialize the default point if self._data["initial_time"] is not None: # pick at that time use_data = act_data[:, [np.round(self._data["time_idx"]).astype(int)]] else: use_data = act_data ind = np.unravel_index( np.argmax(np.abs(use_data), axis=None), use_data.shape ) vertex_id = vertices[ind[0]] publish(self, VertexSelect(hemi=hemi, vertex_id=vertex_id)) def _configure_picking(self): # get data for each hemi for idx, hemi in enumerate(["vol", "lh", "rh"]): hemi_data = self._data.get(hemi) if hemi_data is not None: act_data = hemi_data["array"] if act_data.ndim == 3: act_data = np.linalg.norm(act_data, axis=1) smooth_mat = hemi_data.get("smooth_mat") vertices = hemi_data["vertices"] if hemi == "vol": assert smooth_mat is None smooth_mat = csr_array( (np.ones(len(vertices)), (vertices, np.arange(len(vertices)))) ) self.act_data_smooth[hemi] = (act_data, smooth_mat) self._renderer._update_picking_callback( self._on_mouse_move, self._on_button_press, self._on_button_release, self._on_pick, ) subscribe(self, "vertex_select", self._on_vertex_select) def _configure_tool_bar(self): if not hasattr(self._renderer, "_tool_bar") or self._renderer._tool_bar is None: self._renderer._tool_bar_initialize(name="Toolbar") @_auto_weakref def save_image(filename): self.save_image(filename) self._renderer._tool_bar_add_file_button( name="screenshot", desc="Take a screenshot", func=save_image, ) @_auto_weakref def save_movie(filename): self.save_movie( filename=filename, time_dilation=(1.0 / self.playback_speed) ) self._renderer._tool_bar_add_file_button( name="movie", desc="Save movie...", func=save_movie, shortcut="ctrl+shift+s", ) self._renderer._tool_bar_add_button( name="visibility", desc="Toggle Controls", func=self.toggle_interface, icon_name="visibility_on", ) self._renderer._tool_bar_add_button( name="scale", desc="Auto-Scale", func=self.apply_auto_scaling, ) self._renderer._tool_bar_add_button( name="clear", desc="Clear traces", func=self.clear_glyphs, ) self._renderer._tool_bar_add_spacer() self._renderer._tool_bar_add_button( name="help", desc="Help", func=self.help, shortcut="?", ) def _rotate_camera(self, which, value): _, _, azimuth, elevation, _ = self._renderer.get_camera(rigid=self._rigid) kwargs = dict(update=True) if which == "azimuth": value = azimuth + value # Our view_up threshold is 5/175, so let's be safe here if elevation < 7.5 or elevation > 172.5: kwargs["elevation"] = np.clip(elevation, 10, 170) else: value = np.clip(elevation + value, 10, 170) kwargs[which] = value self._set_camera(**kwargs) def _configure_shortcuts(self): # Remove the default key binding if getattr(self, "iren", None) is not None: self.plotter.iren.clear_key_event_callbacks() # Then, we add our own: self.plotter.add_key_event("i", self.toggle_interface) self.plotter.add_key_event("s", self.apply_auto_scaling) self.plotter.add_key_event("r", self.restore_user_scaling) self.plotter.add_key_event("c", self.clear_glyphs) for key, which, amt in ( ("Left", "azimuth", 10), ("Right", "azimuth", -10), ("Up", "elevation", 10), ("Down", "elevation", -10), ): self.plotter.clear_events_for_key(key) self.plotter.add_key_event(key, partial(self._rotate_camera, which, amt)) def _configure_menu(self): self._renderer._menu_initialize() self._renderer._menu_add_submenu( name="help", desc="Help", ) self._renderer._menu_add_button( menu_name="help", name="help", desc="Show MNE key bindings\t?", func=self.help, ) def _configure_status_bar(self): self._renderer._status_bar_initialize() self.status_msg = self._renderer._status_bar_add_label( self.default_status_bar_msg, stretch=1 ) self.status_progress = self._renderer._status_bar_add_progress_bar() if self.status_progress is not None: self.status_progress.hide() def _on_mouse_move(self, vtk_picker, event): if self._mouse_no_mvt: self._mouse_no_mvt -= 1 def _on_button_press(self, vtk_picker, event): self._mouse_no_mvt = 2 def _on_button_release(self, vtk_picker, event): if self._mouse_no_mvt > 0: x, y = vtk_picker.GetEventPosition() # programmatically detect the picked renderer try: # pyvista<0.30.0 self.picked_renderer = self.plotter.iren.FindPokedRenderer(x, y) except AttributeError: # pyvista>=0.30.0 self.picked_renderer = self.plotter.iren.interactor.FindPokedRenderer( x, y ) # trigger the pick self.plotter.picker.Pick(x, y, 0, self.picked_renderer) self._mouse_no_mvt = 0 def _on_pick(self, vtk_picker, event): if not self.show_traces: return # vtk_picker is a vtkCellPicker cell_id = vtk_picker.GetCellId() mesh = vtk_picker.GetDataSet() if mesh is None or cell_id == -1 or not self._mouse_no_mvt: return # don't pick # 1) Check to see if there are any spheres along the ray if len(self._spheres): collection = vtk_picker.GetProp3Ds() found_sphere = None for ii in range(collection.GetNumberOfItems()): actor = collection.GetItemAsObject(ii) for sphere in self._spheres: if any(a is actor for a in sphere._actors): found_sphere = sphere break if found_sphere is not None: break if found_sphere is not None: assert found_sphere._is_glyph mesh = found_sphere # 2) Remove sphere if it's what we have if hasattr(mesh, "_is_glyph"): self._remove_vertex_glyph(mesh) return # 3) Otherwise, pick the objects in the scene try: hemi = mesh._hemi except AttributeError: # volume hemi = "vol" else: assert hemi in ("lh", "rh") if self.act_data_smooth[hemi][0] is None: # no data to add for hemi return pos = np.array(vtk_picker.GetPickPosition()) if hemi == "vol": # VTK will give us the point closest to the viewer in the vol. # We want to pick the point with the maximum value along the # camera-to-click array, which fortunately we can get "just" # by inspecting the points that are sufficiently close to the # ray. grid = mesh = self._data[hemi]["grid"] vertices = self._data[hemi]["vertices"] coords = self._data[hemi]["grid_coords"][vertices] scalars = grid.cell_data["values"][vertices] spacing = np.array(grid.GetSpacing()) max_dist = np.linalg.norm(spacing) / 2.0 origin = vtk_picker.GetRenderer().GetActiveCamera().GetPosition() ori = pos - origin ori /= np.linalg.norm(ori) # the magic formula: distance from a ray to a given point dists = np.linalg.norm(np.cross(ori, coords - pos), axis=1) assert dists.shape == (len(coords),) mask = dists <= max_dist idx = np.where(mask)[0] if len(idx) == 0: return # weird point on edge of volume? # useful for debugging the ray by mapping it into the volume: # dists = dists - dists.min() # dists = (1. - dists / dists.max()) * self._cmap_range[1] # grid.cell_data['values'][vertices] = dists * mask idx = idx[np.argmax(np.abs(scalars[idx]))] vertex_id = vertices[idx] # Naive way: convert pos directly to idx; i.e., apply mri_src_t # shape = self._data[hemi]['grid_shape'] # taking into account the cell vs point difference (spacing/2) # shift = np.array(grid.GetOrigin()) + spacing / 2. # ijk = np.round((pos - shift) / spacing).astype(int) # vertex_id = np.ravel_multi_index(ijk, shape, order='F') else: vtk_cell = mesh.GetCell(cell_id) cell = [ vtk_cell.GetPointId(point_id) for point_id in range(vtk_cell.GetNumberOfPoints()) ] vertices = mesh.points[cell] idx = np.argmin(abs(vertices - pos), axis=0) vertex_id = cell[idx[0]] publish(self, VertexSelect(hemi=hemi, vertex_id=vertex_id)) def _on_time_change(self, event): """Respond to a time change UI event.""" if event.time == self._current_time: return time_idx = self._to_time_index(event.time) self._update_current_time_idx(time_idx) if self.time_viewer: with disable_ui_events(self): if "time" in self.widgets: self.widgets["time"].set_value(time_idx) if "current_time" in self.widgets: self.widgets["current_time"].set_value(f"{self._current_time: .3f}") self.plot_time_line(update=True) def _on_colormap_range(self, event): """Respond to the colormap_range UI event.""" if event.kind != "distributed_source_power": return lims = {key: getattr(event, key) for key in ("fmin", "fmid", "fmax", "alpha")} # Check if limits have changed at all. if all(val is None or val == self._data[key] for key, val in lims.items()): return # Update the GUI elements. with disable_ui_events(self): for key, val in lims.items(): if val is not None: if key in self.widgets: self.widgets[key].set_value(val * self._data["fscale"]) entry_key = "entry_" + key if entry_key in self.widgets: self.widgets[entry_key].set_value(val * self._data["fscale"]) # Update the render. self._update_colormap_range(**lims) def _on_vertex_select(self, event): """Respond to vertex_select UI event.""" if event.hemi == "vol": try: mesh = self._data[event.hemi]["grid"] except KeyError: return else: try: mesh = self._layered_meshes[event.hemi]._polydata except KeyError: return if self.traces_mode == "label": self._add_label_glyph(event.hemi, mesh, event.vertex_id) else: self._add_vertex_glyph(event.hemi, mesh, event.vertex_id) def _add_label_glyph(self, hemi, mesh, vertex_id): if hemi == "vol": return label_id = self._vertex_to_label_id[hemi][vertex_id] label = self._annotation_labels[hemi][label_id] # remove the patch if already picked if label_id in self.picked_patches[hemi]: self._remove_label_glyph(hemi, label_id) return if hemi == label.hemi: self.add_label(label, borders=True) self.picked_patches[hemi].append(label_id) def _remove_label_glyph(self, hemi, label_id): label = self._annotation_labels[hemi][label_id] label._line.remove() self.color_cycle.restore(label._color) self.mpl_canvas.update_plot() self._layered_meshes[hemi].remove_overlay(label.name) self.picked_patches[hemi].remove(label_id) def _add_vertex_glyph(self, hemi, mesh, vertex_id, update=True): if vertex_id in self.picked_points[hemi]: return # skip if the wrong hemi is selected if self.act_data_smooth[hemi][0] is None: return color = next(self.color_cycle) line = self.plot_time_course(hemi, vertex_id, color, update=update) if hemi == "vol": ijk = np.unravel_index( vertex_id, np.array(mesh.GetDimensions()) - 1, order="F" ) voxel = mesh.GetCell(*ijk) center = np.empty(3) voxel.GetCentroid(center) else: center = mesh.GetPoints().GetPoint(vertex_id) del mesh # from the picked renderer to the subplot coords try: lst = self._renderer._all_renderers._renderers except AttributeError: lst = self._renderer._all_renderers rindex = lst.index(self.picked_renderer) row, col = self._renderer._index_to_loc(rindex) actors = list() spheres = list() for _ in self._iter_views(hemi): # Using _sphere() instead of renderer.sphere() for 2 reasons: # 1) renderer.sphere() fails on Windows in a scenario where a lot # of picking requests are done in a short span of time (could be # mitigated with synchronization/delay?) # 2) the glyph filter is used in renderer.sphere() but only one # sphere is required in this function. actor, sphere = self._renderer._sphere( center=np.array(center), color=color, radius=4.0, ) actors.append(actor) spheres.append(sphere) # add metadata for picking for sphere in spheres: sphere._is_glyph = True sphere._hemi = hemi sphere._line = line sphere._actors = actors sphere._color = color sphere._vertex_id = vertex_id self.picked_points[hemi].append(vertex_id) self._spheres.extend(spheres) self.pick_table[vertex_id] = spheres return sphere def _remove_vertex_glyph(self, mesh, render=True): vertex_id = mesh._vertex_id if vertex_id not in self.pick_table: return hemi = mesh._hemi color = mesh._color spheres = self.pick_table[vertex_id] spheres[0]._line.remove() self.mpl_canvas.update_plot() self.picked_points[hemi].remove(vertex_id) with warnings.catch_warnings(record=True): # We intentionally ignore these in case we have traversed the # entire color cycle warnings.simplefilter("ignore") self.color_cycle.restore(color) for sphere in spheres: # remove all actors self.plotter.remove_actor(sphere._actors, render=False) sphere._actors = None self._spheres.pop(self._spheres.index(sphere)) if render: self._renderer._update() self.pick_table.pop(vertex_id) def clear_glyphs(self): """Clear the picking glyphs.""" if not self.time_viewer: return for sphere in list(self._spheres): # will remove itself, so copy self._remove_vertex_glyph(sphere, render=False) assert sum(len(v) for v in self.picked_points.values()) == 0 assert len(self.pick_table) == 0 assert len(self._spheres) == 0 for hemi in self._hemis: for label_id in list(self.picked_patches[hemi]): self._remove_label_glyph(hemi, label_id) assert sum(len(v) for v in self.picked_patches.values()) == 0 if self.rms is not None: self.rms.remove() self.rms = None self._renderer._update() @fill_doc def plot_time_course(self, hemi, vertex_id, color, update=True): """Plot the vertex time course. Parameters ---------- hemi : str The hemisphere id of the vertex. vertex_id : int The vertex identifier in the mesh. color : matplotlib color The color of the time course. %(brain_update)s Returns ------- line : matplotlib object The time line object. """ if self.mpl_canvas is None: return time = self._data["time"].copy() # avoid circular ref mni = None if hemi == "vol": hemi_str = "V" xfm = read_talxfm(self._subject, self._subjects_dir) if self._units == "mm": xfm["trans"][:3, 3] *= 1000.0 ijk = np.unravel_index(vertex_id, self._data[hemi]["grid_shape"], order="F") src_mri_t = self._data[hemi]["grid_src_mri_t"] mni = apply_trans(xfm["trans"] @ src_mri_t, ijk) else: hemi_str = "L" if hemi == "lh" else "R" try: mni = vertex_to_mni( vertices=vertex_id, hemis=0 if hemi == "lh" else 1, subject=self._subject, subjects_dir=self._subjects_dir, ) except Exception: mni = None if mni is not None: mni = " MNI: " + ", ".join(f"{m:5.1f}" for m in mni) else: mni = "" label = f"{hemi_str}:{str(vertex_id).ljust(6)}{mni}" act_data, smooth = self.act_data_smooth[hemi] if smooth is not None: act_data = (smooth[[vertex_id]] @ act_data)[0] else: act_data = act_data[vertex_id].copy() line = self.mpl_canvas.plot( time, act_data, label=label, lw=1.0, color=color, zorder=4, update=update, ) return line @fill_doc def plot_time_line(self, update=True): """Add the time line to the MPL widget. Parameters ---------- %(brain_update)s """ if self.mpl_canvas is None: return if isinstance(self.show_traces, bool) and self.show_traces: # add time information current_time = self._current_time if not hasattr(self, "time_line"): self.time_line = self.mpl_canvas.plot_time_line( x=current_time, label="time", color=self._fg_color, lw=1, update=update, ) self.time_line.set_xdata([current_time]) if update: self.mpl_canvas.update_plot() def _configure_help(self): pairs = [ ("?", "Display help window"), ("i", "Toggle interface"), ("s", "Apply auto-scaling"), ("r", "Restore original clim"), ("c", "Clear all traces"), ("n", "Shift the time forward by the playback speed"), ("b", "Shift the time backward by the playback speed"), ("Space", "Start/Pause playback"), ("Up", "Decrease camera elevation angle"), ("Down", "Increase camera elevation angle"), ("Left", "Decrease camera azimuth angle"), ("Right", "Increase camera azimuth angle"), ] text1, text2 = zip(*pairs) text1 = "\n".join(text1) text2 = "\n".join(text2) self.help_canvas = self._renderer._window_get_simple_canvas( width=5, height=2, dpi=80 ) _show_help_fig( col1=text1, col2=text2, fig_help=self.help_canvas.fig, ax=self.help_canvas.axes, show=False, ) def help(self): """Display the help window.""" self.help_canvas.show() def _clear_callbacks(self): # Remove the default key binding if getattr(self, "iren", None) is not None: self.plotter.iren.clear_key_event_callbacks() def _clear_widgets(self): if not hasattr(self, "widgets"): return for widget in self.widgets.values(): if widget is not None: for key in ("triggered", "floatValueChanged"): setattr(widget, key, None) self.widgets.clear() @property def interaction(self): """The interaction style.""" return self._interaction @interaction.setter def interaction(self, interaction): """Set the interaction style.""" _validate_type(interaction, str, "interaction") _check_option("interaction", interaction, ("trackball", "terrain")) for _ in self._iter_views("vol"): # will traverse all self._renderer.set_interaction(interaction) def _cortex_colormap(self, cortex): """Return the colormap corresponding to the cortex.""" from matplotlib.colors import ListedColormap from .._3d import _get_cmap colormap_map = dict( classic=dict(colormap="Greys", vmin=-1, vmax=2), high_contrast=dict(colormap="Greys", vmin=-0.1, vmax=1.3), low_contrast=dict(colormap="Greys", vmin=-5, vmax=5), bone=dict(colormap="bone_r", vmin=-0.2, vmax=2), ) _validate_type(cortex, (str, dict, list, tuple), "cortex") if isinstance(cortex, str): if cortex in colormap_map: cortex = colormap_map[cortex] else: cortex = [cortex] * 2 if isinstance(cortex, (list, tuple)): _check_option( "len(cortex)", len(cortex), (2, 3), extra="when cortex is a list or tuple", ) if len(cortex) == 3: cortex = [cortex] * 2 cortex = list(cortex) for ci, c in enumerate(cortex): cortex[ci] = _to_rgb(c, name="cortex") cortex = dict( colormap=ListedColormap(cortex, name="custom binary"), vmin=0, vmax=1 ) cortex = dict( vmin=float(cortex["vmin"]), vmax=float(cortex["vmax"]), colormap=_get_cmap(cortex["colormap"]), ) return cortex def _remove(self, item, render=False): """Remove actors from the rendered scene.""" if item in self._actors: logger.debug(f"Removing {len(self._actors[item])} {item} actor(s)") for actor in self._actors[item]: self._renderer.plotter.remove_actor(actor, render=False) self._actors.pop(item) # remove actor list if render: self._renderer._update() def _add_actor(self, item, actor): """Add an actor to the internal register.""" if item in self._actors: # allows adding more than one self._actors[item].append(actor) else: self._actors[item] = [actor] @verbose def add_data( self, array, fmin=None, fmid=None, fmax=None, thresh=None, center=None, transparent=False, colormap="auto", alpha=1, vertices=None, smoothing_steps=None, time=None, time_label="auto", colorbar=True, hemi=None, remove_existing=None, time_label_size=None, initial_time=None, scale_factor=None, vector_alpha=None, clim=None, src=None, volume_options=0.4, colorbar_kwargs=None, verbose=None, ): """Display data from a numpy array on the surface or volume. This provides a similar interface to PySurfer, but it displays it with a single colormap. It offers more flexibility over the colormap, and provides a way to display four-dimensional data (i.e., a timecourse) or five-dimensional data (i.e., a vector-valued timecourse). .. note:: ``fmin`` sets the low end of the colormap, and is separate from thresh (this is a different convention from PySurfer). Parameters ---------- array : numpy array, shape (n_vertices[, 3][, n_times]) Data array. For the data to be understood as vector-valued (3 values per vertex corresponding to X/Y/Z surface RAS), then ``array`` must be have all 3 dimensions. If vectors with no time dimension are desired, consider using a singleton (e.g., ``np.newaxis``) to create a "time" dimension and pass ``time_label=None`` (vector values are not supported). %(fmin_fmid_fmax)s %(thresh)s %(center)s %(transparent)s colormap : str, list of color, or array Name of matplotlib colormap to use, a list of matplotlib colors, or a custom look up table (an n x 4 array coded with RBGA values between 0 and 255), the default "auto" chooses a default divergent colormap, if "center" is given (currently "icefire"), otherwise a default sequential colormap (currently "rocket"). alpha : float in [0, 1] Alpha level to control opacity of the overlay. vertices : numpy array Vertices for which the data is defined (needed if ``len(data) < nvtx``). smoothing_steps : int or None Number of smoothing steps (smoothing is used if len(data) < nvtx) The value 'nearest' can be used too. None (default) will use as many as necessary to fill the surface. time : numpy array Time points in the data array (if data is 2D or 3D). %(time_label)s colorbar : bool Whether to add a colorbar to the figure. Can also be a tuple to give the (row, col) index of where to put the colorbar. hemi : str | None If None, it is assumed to belong to the hemisphere being shown. If two hemispheres are being shown, an error will be thrown. remove_existing : bool Not supported yet. Remove surface added by previous "add_data" call. Useful for conserving memory when displaying different data in a loop. time_label_size : int Font size of the time label (default 14). initial_time : float | None Time initially shown in the plot. ``None`` to use the first time sample (default). scale_factor : float | None (default) The scale factor to use when displaying glyphs for vector-valued data. vector_alpha : float | None Alpha level to control opacity of the arrows. Only used for vector-valued data. If None (default), ``alpha`` is used. clim : dict Original clim arguments. %(src_volume_options)s colorbar_kwargs : dict | None Options to pass to ``pyvista.Plotter.add_scalar_bar`` (e.g., ``dict(title_font_size=10)``). %(verbose)s Notes ----- If the data is defined for a subset of vertices (specified by the "vertices" parameter), a smoothing method is used to interpolate the data onto the high resolution surface. If the data is defined for subsampled version of the surface, smoothing_steps can be set to None, in which case only as many smoothing steps are applied until the whole surface is filled with non-zeros. Due to a VTK alpha rendering bug, ``vector_alpha`` is clamped to be strictly < 1. """ _validate_type(transparent, bool, "transparent") _validate_type(vector_alpha, ("numeric", None), "vector_alpha") _validate_type(scale_factor, ("numeric", None), "scale_factor") # those parameters are not supported yet, only None is allowed _check_option("thresh", thresh, [None]) _check_option("remove_existing", remove_existing, [None]) _validate_type(time_label_size, (None, "numeric"), "time_label_size") if time_label_size is not None: time_label_size = float(time_label_size) if time_label_size < 0: raise ValueError( f"time_label_size must be positive, got {time_label_size}" ) hemi = self._check_hemi(hemi, extras=["vol"]) stc, array, vertices = self._check_stc(hemi, array, vertices) array = np.asarray(array) vector_alpha = alpha if vector_alpha is None else vector_alpha self._data["vector_alpha"] = vector_alpha self._data["scale_factor"] = scale_factor # Create time array and add label if > 1D if array.ndim <= 1: time_idx = 0 else: # check time array if time is None: time = np.arange(array.shape[-1]) else: time = np.asarray(time) if time.shape != (array.shape[-1],): raise ValueError( f"time has shape {time.shape}, but need shape " f"{(array.shape[-1],)} (array.shape[-1])" ) self._data["time"] = time if self._n_times is None: self._times = time elif len(time) != self._n_times: raise ValueError("New n_times is different from previous n_times") elif not np.array_equal(time, self._times): raise ValueError( "Not all time values are consistent with previously set times." ) # initial time if initial_time is None: time_idx = 0 else: time_idx = self._to_time_index(initial_time) # time label time_label, _ = _handle_time(time_label, "s", time) y_txt = 0.05 + 0.1 * bool(colorbar) if array.ndim == 3: if array.shape[1] != 3: raise ValueError( "If array has 3 dimensions, array.shape[1] must equal 3, got " f"{array.shape[1]}" ) fmin, fmid, fmax = _update_limits(fmin, fmid, fmax, center, array) if colormap == "auto": colormap = "mne" if center is not None else "hot" if smoothing_steps is None: smoothing_steps = 7 elif smoothing_steps == "nearest": smoothing_steps = -1 elif isinstance(smoothing_steps, int): if smoothing_steps < 0: raise ValueError( "Expected value of `smoothing_steps` is positive but " f"{smoothing_steps} was given." ) else: raise TypeError( "Expected type of `smoothing_steps` is int or NoneType but " f"{type(smoothing_steps)} was given." ) self._data["stc"] = stc self._data["src"] = src self._data["smoothing_steps"] = smoothing_steps self._data["clim"] = clim self._data["time"] = time self._data["initial_time"] = initial_time self._data["time_label"] = time_label self._data["initial_time_idx"] = time_idx self._data["time_idx"] = time_idx self._data["transparent"] = transparent # data specific for a hemi self._data[hemi] = dict() self._data[hemi]["glyph_dataset"] = None self._data[hemi]["glyph_mapper"] = None self._data[hemi]["glyph_actor"] = None self._data[hemi]["array"] = array self._data[hemi]["vertices"] = vertices self._data["alpha"] = alpha self._data["colormap"] = colormap self._data["center"] = center self._data["fmin"] = fmin self._data["fmid"] = fmid self._data["fmax"] = fmax self._update_colormap_range() # 1) add the surfaces first actor = None for _ in self._iter_views(hemi): if hemi in ("lh", "rh"): actor = self._layered_meshes[hemi]._actor else: src_vol = src[2:] if src.kind == "mixed" else src actor, _ = self._add_volume_data(hemi, src_vol, volume_options) assert actor is not None # should have added one self._add_actor("data", actor) # 2) update time and smoothing properties # set_data_smoothing calls "_update_current_time_idx" for us, which will set # _current_time self.set_time_interpolation(self.time_interpolation) self.set_data_smoothing(self._data["smoothing_steps"]) # 3) add the other actors if colorbar is True: # bottom left by default colorbar = (self._subplot_shape[0] - 1, 0) for ri, ci, v in self._iter_views(hemi): # Add the time label to the bottommost view do = (ri, ci) == colorbar if not self._time_label_added and time_label is not None and do: time_actor = self._renderer.text2d( x_window=0.95, y_window=y_txt, color=self._fg_color, size=time_label_size, text=time_label(self._current_time), justification="right", ) self._data["time_actor"] = time_actor self._time_label_added = True if colorbar and self._scalar_bar is None and do: kwargs = dict( source=actor, n_labels=8, color=self._fg_color, bgcolor=self._brain_color[:3], ) kwargs.update(colorbar_kwargs or {}) self._scalar_bar = self._renderer.scalarbar(**kwargs) self._set_camera(**views_dicts[hemi][v]) # 4) update the scalar bar and opacity (and render) self._update_colormap_range(alpha=alpha) # 5) enable UI events to interact with the data subscribe(self, "colormap_range", self._on_colormap_range) if time is not None and len(time) > 1: subscribe(self, "time_change", self._on_time_change) def remove_data(self): """Remove rendered data from the mesh.""" self._remove("data", render=True) # Stop listening to events if "time_change" in _get_event_channel(self): unsubscribe(self, "time_change") def _iter_views(self, hemi): """Iterate over rows and columns that need to be added to.""" hemi_dict = dict(lh=[0], rh=[0], vol=[0]) if self._hemi == "split": hemi_dict.update(rh=[1], vol=[0, 1]) for vi, view in enumerate(self._views): view_dict = dict(lh=[vi], rh=[vi], vol=[vi]) if self._hemi == "split": view_dict.update(vol=[vi, vi]) if self._view_layout == "vertical": rows, cols = view_dict, hemi_dict # views are rows, hemis cols else: rows, cols = hemi_dict, view_dict # hemis are rows, views cols for ri, ci in zip(rows[hemi], cols[hemi]): self._renderer.subplot(ri, ci) yield ri, ci, view def remove_labels(self): """Remove all the ROI labels from the image.""" for hemi in self._hemis: mesh = self._layered_meshes[hemi] for label in self._labels[hemi]: mesh.remove_overlay(label.name) self._labels[hemi].clear() self._renderer._update() def remove_annotations(self): """Remove all annotations from the image.""" for hemi in self._hemis: if hemi in self._layered_meshes: mesh = self._layered_meshes[hemi] mesh.remove_overlay(self._annots[hemi]) if hemi in self._annots: self._annots[hemi].clear() self._renderer._update() def _add_volume_data(self, hemi, src, volume_options): from ...source_space import SourceSpaces from ..backends._pyvista import _hide_testing_actor _validate_type(src, SourceSpaces, "src") _check_option("src.kind", src.kind, ("volume",)) _validate_type(volume_options, (dict, "numeric", None), "volume_options") assert hemi == "vol" if not isinstance(volume_options, dict): volume_options = dict( resolution=float(volume_options) if volume_options is not None else None ) volume_options = _handle_default("volume_options", volume_options) allowed_types = ( ["resolution", (None, "numeric")], ["blending", (str,)], ["alpha", ("numeric", None)], ["surface_alpha", (None, "numeric")], ["silhouette_alpha", (None, "numeric")], ["silhouette_linewidth", ("numeric",)], ) for key, types in allowed_types: _validate_type(volume_options[key], types, f"volume_options[{repr(key)}]") extra_keys = set(volume_options) - set(a[0] for a in allowed_types) if len(extra_keys): raise ValueError(f"volume_options got unknown keys {sorted(extra_keys)}") blending = _check_option( 'volume_options["blending"]', volume_options["blending"], ("composite", "mip"), ) alpha = volume_options["alpha"] if alpha is None: alpha = 0.4 if self._data[hemi]["array"].ndim == 3 else 1.0 alpha = np.clip(float(alpha), 0.0, 1.0) resolution = volume_options["resolution"] surface_alpha = volume_options["surface_alpha"] if surface_alpha is None: surface_alpha = min(alpha / 2.0, 0.1) silhouette_alpha = volume_options["silhouette_alpha"] if silhouette_alpha is None: silhouette_alpha = surface_alpha / 4.0 silhouette_linewidth = volume_options["silhouette_linewidth"] del volume_options volume_pos = self._data[hemi].get("grid_volume_pos") volume_neg = self._data[hemi].get("grid_volume_neg") center = self._data["center"] if volume_pos is None: xyz = np.meshgrid(*[np.arange(s) for s in src[0]["shape"]], indexing="ij") dimensions = np.array(src[0]["shape"], int) mult = 1000 if self._units == "mm" else 1 src_mri_t = src[0]["src_mri_t"]["trans"].copy() src_mri_t[:3] *= mult if resolution is not None: resolution = resolution * mult / 1000.0 # to mm del src, mult coords = np.array([c.ravel(order="F") for c in xyz]).T coords = apply_trans(src_mri_t, coords) self.geo[hemi] = Bunch(coords=coords) vertices = self._data[hemi]["vertices"] assert self._data[hemi]["array"].shape[0] == len(vertices) # MNE constructs the source space on a uniform grid in MRI space, # but mne coreg can change it to be non-uniform, so we need to # use all three elements here assert np.allclose(src_mri_t[:3, :3], np.diag(np.diag(src_mri_t)[:3])) spacing = np.diag(src_mri_t)[:3] origin = src_mri_t[:3, 3] - spacing / 2.0 scalars = np.zeros(np.prod(dimensions)) scalars[vertices] = 1.0 # for the outer mesh grid, grid_mesh, volume_pos, volume_neg = self._renderer._volume( dimensions, origin, spacing, scalars, surface_alpha, resolution, blending, center, ) self._data[hemi]["alpha"] = alpha # incorrectly set earlier self._data[hemi]["grid"] = grid self._data[hemi]["grid_mesh"] = grid_mesh self._data[hemi]["grid_coords"] = coords self._data[hemi]["grid_src_mri_t"] = src_mri_t self._data[hemi]["grid_shape"] = dimensions self._data[hemi]["grid_volume_pos"] = volume_pos self._data[hemi]["grid_volume_neg"] = volume_neg actor_pos, _ = self._renderer.plotter.add_actor( volume_pos, name=None, culling=False, reset_camera=False, render=False ) actor_neg = actor_mesh = None if volume_neg is not None: actor_neg, _ = self._renderer.plotter.add_actor( volume_neg, name=None, culling=False, reset_camera=False, render=False ) grid_mesh = self._data[hemi]["grid_mesh"] if grid_mesh is not None: actor_mesh, prop = self._renderer.plotter.add_actor( grid_mesh, name=None, culling=False, pickable=False, reset_camera=False, render=False, ) prop.SetColor(*self._brain_color[:3]) prop.SetOpacity(surface_alpha) if silhouette_alpha > 0 and silhouette_linewidth > 0: for _ in self._iter_views("vol"): self._renderer._silhouette( mesh=grid_mesh.GetInput(), color=self._brain_color[:3], line_width=silhouette_linewidth, alpha=silhouette_alpha, ) for actor in (actor_pos, actor_neg, actor_mesh): if actor is not None: _hide_testing_actor(actor) return actor_pos, actor_neg def add_label( self, label, color=None, alpha=1, scalar_thresh=None, borders=False, hemi=None, subdir=None, ): """Add an ROI label to the image. Parameters ---------- label : str | instance of Label Label filepath or name. Can also be an instance of an object with attributes "hemi", "vertices", "name", and optionally "color" and "values" (if scalar_thresh is not None). color : matplotlib-style color | None Anything matplotlib accepts: string, RGB, hex, etc. (default "crimson"). alpha : float in [0, 1] Alpha level to control opacity. scalar_thresh : None | float Threshold the label ids using this value in the label file's scalar field (i.e. label only vertices with scalar >= thresh). borders : bool | int Show only label borders. If int, specify the number of steps (away from the true border) along the cortical mesh to include as part of the border definition. hemi : str | None If None, it is assumed to belong to the hemisphere being shown. subdir : None | str If a label is specified as name, subdir can be used to indicate that the label file is in a sub-directory of the subject's label directory rather than in the label directory itself (e.g. for ``$SUBJECTS_DIR/$SUBJECT/label/aparc/lh.cuneus.label`` ``brain.add_label('cuneus', subdir='aparc')``). Notes ----- To remove previously added labels, run Brain.remove_labels(). """ from ...label import read_label if isinstance(label, str): if color is None: color = "crimson" if os.path.isfile(label): filepath = label label = read_label(filepath) hemi = label.hemi label_name = os.path.basename(filepath).split(".")[1] else: hemi = self._check_hemi(hemi) label_name = label label_fname = ".".join([hemi, label_name, "label"]) if subdir is None: filepath = op.join( self._subjects_dir, self._subject, "label", label_fname ) else: filepath = op.join( self._subjects_dir, self._subject, "label", subdir, label_fname ) if not os.path.exists(filepath): raise ValueError(f"Label file {filepath} does not exist") label = read_label(filepath) ids = label.vertices scalars = label.values else: # try to extract parameters from label instance try: hemi = label.hemi ids = label.vertices if label.name is None: label.name = "unnamed" + str(self._unnamed_label_id) self._unnamed_label_id += 1 label_name = str(label.name) if color is None: if hasattr(label, "color") and label.color is not None: color = label.color else: color = "crimson" if scalar_thresh is not None: scalars = label.values except Exception: raise ValueError( "Label was not a filename (str), and could " "not be understood as a class. The class " 'must have attributes "hemi", "vertices", ' '"name", and (if scalar_thresh is not None)' '"values"' ) hemi = self._check_hemi(hemi) if scalar_thresh is not None: ids = ids[scalars >= scalar_thresh] if self.time_viewer and self.show_traces and self.traces_mode == "label": stc = self._data["stc"] src = self._data["src"] tc = stc.extract_label_time_course( label, src=src, mode=self.label_extract_mode ) tc = tc[0] if tc.ndim == 2 else tc[0, 0, :] color = next(self.color_cycle) line = self.mpl_canvas.plot( self._data["time"], tc, label=label_name, color=color ) else: line = None orig_color = color color = _to_rgb(color, alpha, alpha=True) cmap = np.array( [ ( 0, 0, 0, 0, ), color, ] ) ctable = np.round(cmap * 255).astype(np.uint8) scalars = np.zeros(self.geo[hemi].coords.shape[0]) scalars[ids] = 1 if borders: keep_idx = _mesh_borders(self.geo[hemi].faces, scalars) show = np.zeros(scalars.size, dtype=np.int64) if isinstance(borders, int): for _ in range(borders): keep_idx = np.isin(self.geo[hemi].faces.ravel(), keep_idx) keep_idx.shape = self.geo[hemi].faces.shape keep_idx = self.geo[hemi].faces[np.any(keep_idx, axis=1)] keep_idx = np.unique(keep_idx) show[keep_idx] = 1 scalars *= show for _, _, v in self._iter_views(hemi): mesh = self._layered_meshes[hemi] mesh.add_overlay( scalars=scalars, colormap=ctable, rng=[np.min(scalars), np.max(scalars)], opacity=alpha, name=label_name, ) if self.time_viewer and self.show_traces and self.traces_mode == "label": label._color = orig_color label._line = line self._labels[hemi].append(label) self._renderer._update() @fill_doc def add_forward(self, fwd, trans, alpha=1, scale=None): """Add a quiver to render positions of dipoles. Parameters ---------- %(fwd)s %(trans_not_none)s %(alpha)s Default 1. scale : None | float The size of the arrow representing the dipoles in :class:`mne.viz.Brain` units. Default 1.5mm. Notes ----- .. versionadded:: 1.0 """ head_mri_t = _get_trans(trans, "head", "mri", allow_none=False)[0] del trans if scale is None: scale = 1.5 if self._units == "mm" else 1.5e-3 error_msg = ( 'Unexpected forward model coordinate frame {}, must be "head" or "mri"' ) if fwd["coord_frame"] in _frame_to_str: fwd_frame = _frame_to_str[fwd["coord_frame"]] if fwd_frame == "mri": fwd_trans = Transform("mri", "mri") elif fwd_frame == "head": fwd_trans = head_mri_t else: raise RuntimeError(error_msg.format(fwd_frame)) else: raise RuntimeError(error_msg.format(fwd["coord_frame"])) for actor in _plot_forward( self._renderer, fwd, fwd_trans, fwd_scale=1e3 if self._units == "mm" else 1, scale=scale, alpha=alpha, ): self._add_actor("forward", actor) self._renderer._update() def remove_forward(self): """Remove forward sources from the rendered scene.""" self._remove("forward", render=True) @fill_doc def add_dipole(self, dipole, trans, colors="red", alpha=1, scales=None): """Add a quiver to render positions of dipoles. Parameters ---------- dipole : instance of Dipole Dipole object containing position, orientation and amplitude of one or more dipoles or in the forward solution. %(trans_not_none)s colors : list | matplotlib-style color | None A single color or list of anything matplotlib accepts: string, RGB, hex, etc. Default red. %(alpha)s Default 1. scales : list | float | None The size of the arrow representing the dipole in :class:`mne.viz.Brain` units. Default 5mm. Notes ----- .. versionadded:: 1.0 """ head_mri_t = _get_trans(trans, "head", "mri", allow_none=False)[0] del trans n_dipoles = len(dipole) if not isinstance(colors, (list, tuple)): colors = [colors] * n_dipoles # make into list if len(colors) != n_dipoles: raise ValueError( f"The number of colors ({len(colors)}) " f"and dipoles ({n_dipoles}) must match" ) colors = [ _to_rgb(color, name=f"colors[{ci}]") for ci, color in enumerate(colors) ] if scales is None: scales = 5 if self._units == "mm" else 5e-3 if not isinstance(scales, (list, tuple)): scales = [scales] * n_dipoles # make into list if len(scales) != n_dipoles: raise ValueError( f"The number of scales ({len(scales)}) " f"and dipoles ({n_dipoles}) must match" ) pos = apply_trans(head_mri_t, dipole.pos) pos *= 1e3 if self._units == "mm" else 1 for _ in self._iter_views("vol"): for this_pos, this_ori, color, scale in zip( pos, dipole.ori, colors, scales ): actor, _ = self._renderer.quiver3d( *this_pos, *this_ori, color=color, opacity=alpha, mode="arrow", scale=scale, ) self._add_actor("dipole", actor) self._renderer._update() def remove_dipole(self): """Remove dipole objects from the rendered scene.""" self._remove("dipole", render=True) @fill_doc def add_head(self, dense=True, color="gray", alpha=0.5): """Add a mesh to render the outer head surface. Parameters ---------- dense : bool Whether to plot the dense head (``seghead``) or the less dense head (``head``). %(color_matplotlib)s %(alpha)s Notes ----- .. versionadded:: 0.24 """ # load head surf = _get_head_surface( "seghead" if dense else "head", self._subject, self._subjects_dir ) verts, triangles = surf["rr"], surf["tris"] verts *= 1e3 if self._units == "mm" else 1 color = _to_rgb(color) for _ in self._iter_views("vol"): actor, _ = self._renderer.mesh( *verts.T, triangles=triangles, color=color, opacity=alpha, render=False, ) self._add_actor("head", actor) self._renderer._update() def remove_head(self): """Remove head objects from the rendered scene.""" self._remove("head", render=True) @fill_doc def add_skull(self, outer=True, color="gray", alpha=0.5): """Add a mesh to render the skull surface. Parameters ---------- outer : bool Adds the outer skull if ``True``, otherwise adds the inner skull. %(color_matplotlib)s %(alpha)s Notes ----- .. versionadded:: 0.24 """ surf = _get_skull_surface( "outer" if outer else "inner", self._subject, self._subjects_dir ) verts, triangles = surf["rr"], surf["tris"] verts *= 1e3 if self._units == "mm" else 1 color = _to_rgb(color) for _ in self._iter_views("vol"): actor, _ = self._renderer.mesh( *verts.T, triangles=triangles, color=color, opacity=alpha, reset_camera=False, render=False, ) self._add_actor("skull", actor) self._renderer._update() def remove_skull(self): """Remove skull objects from the rendered scene.""" self._remove("skull", render=True) @fill_doc def add_volume_labels( self, aseg="auto", labels=None, colors=None, alpha=0.5, smooth=0.9, fill_hole_size=None, legend=None, ): """Add labels to the rendering from an anatomical segmentation. Parameters ---------- %(aseg)s labels : list Labeled regions of interest to plot. See :func:`mne.get_montage_volume_labels` for one way to determine regions of interest. Regions can also be chosen from the :term:`FreeSurfer LUT`. colors : list | matplotlib-style color | None A list of anything matplotlib accepts: string, RGB, hex, etc. (default :term:`FreeSurfer LUT` colors). %(alpha)s %(smooth)s fill_hole_size : int | None The size of holes to remove in the mesh in voxels. Default is None, no holes are removed. Warning, this dilates the boundaries of the surface by ``fill_hole_size`` number of voxels so use the minimal size. legend : bool | None | dict Add a legend displaying the names of the ``labels``. Default (None) is ``True`` if the number of ``labels`` is 10 or fewer. Can also be a dict of ``kwargs`` to pass to ``pyvista.Plotter.add_legend``. Notes ----- .. versionadded:: 0.24 """ aseg, aseg_data = _get_aseg(aseg, self._subject, self._subjects_dir) vox_mri_t = aseg.header.get_vox2ras_tkr() mult = 1e-3 if self._units == "m" else 1 vox_mri_t[:3] *= mult del aseg # read freesurfer lookup table lut, fs_colors = read_freesurfer_lut() if labels is None: # assign default ROI labels based on indices lut_r = {v: k for k, v in lut.items()} labels = [lut_r[idx] for idx in DEFAULTS["volume_label_indices"]] _validate_type(fill_hole_size, (int, None), "fill_hole_size") _validate_type(legend, (bool, None, dict), "legend") if legend is None: legend = len(labels) < 11 if colors is None: colors = [fs_colors[label] / 255 for label in labels] elif not isinstance(colors, (list, tuple)): colors = [colors] * len(labels) # make into list colors = [ _to_rgb(color, name=f"colors[{ci}]") for ci, color in enumerate(colors) ] surfs = _marching_cubes( aseg_data, [lut[label] for label in labels], smooth=smooth, fill_hole_size=fill_hole_size, ) for label, color, (verts, triangles) in zip(labels, colors, surfs): if len(verts) == 0: # not in aseg vals warn( f"Value {lut[label]} not found for label " f"{repr(label)} in anatomical segmentation file " ) continue verts = apply_trans(vox_mri_t, verts) for _ in self._iter_views("vol"): actor, _ = self._renderer.mesh( *verts.T, triangles=triangles, color=color, opacity=alpha, reset_camera=False, render=False, ) self._add_actor("volume_labels", actor) if legend or isinstance(legend, dict): # use empty kwargs for legend = True legend = legend if isinstance(legend, dict) else dict() self._renderer.plotter.add_legend(list(zip(labels, colors)), **legend) self._renderer._update() def remove_volume_labels(self): """Remove the volume labels from the rendered scene.""" self._remove("volume_labels", render=True) self._renderer.plotter.remove_legend() @fill_doc def add_foci( self, coords, coords_as_verts=False, map_surface=None, scale_factor=1, color="white", alpha=1, name=None, hemi=None, resolution=50, ): """Add spherical foci, possibly mapping to displayed surf. The foci spheres can be displayed at the coordinates given, or mapped through a surface geometry. In other words, coordinates from a volume-based analysis in MNI space can be displayed on an inflated average surface by finding the closest vertex on the white surface and mapping to that vertex on the inflated mesh. Parameters ---------- coords : ndarray, shape (n_coords, 3) Coordinates in stereotaxic space (default) or array of vertex ids (with ``coord_as_verts=True``). coords_as_verts : bool Whether the coords parameter should be interpreted as vertex ids. map_surface : str | None Surface to project the coordinates to, or None to use raw coords. When set to a surface, each foci is positioned at the closest vertex in the mesh. scale_factor : float Controls the size of the foci spheres (relative to 1cm). %(color_matplotlib)s %(alpha)s Default is 1. name : str Internal name to use. hemi : str | None If None, it is assumed to belong to the hemisphere being shown. If two hemispheres are being shown, an error will be thrown. resolution : int The resolution of the spheres. """ hemi = self._check_hemi(hemi, extras=["vol"]) # Figure out how to interpret the first parameter if coords_as_verts: coords = self.geo[hemi].coords[coords] map_surface = None # Possibly map the foci coords through a surface if map_surface is not None: foci_surf = _Surface( self._subject, hemi, map_surface, self._subjects_dir, offset=0, units=self._units, x_dir=self._rigid[0, :3], ) foci_surf.load_geometry() foci_vtxs = np.argmin(cdist(foci_surf.coords, coords), axis=0) coords = self.geo[hemi].coords[foci_vtxs] # Convert the color code color = _to_rgb(color) if self._units == "m": scale_factor = scale_factor / 1000.0 for _, _, v in self._iter_views(hemi): self._renderer.sphere( center=coords, color=color, scale=(10.0 * scale_factor), opacity=alpha, resolution=resolution, ) self._set_camera(**views_dicts[hemi][v]) self._renderer._update() # Store the foci in the Brain._data dictionary data_foci = coords if "foci" in self._data.get(hemi, []): data_foci = np.vstack((self._data[hemi]["foci"], data_foci)) self._data[hemi] = self._data.get(hemi, dict()) # no data added yet self._data[hemi]["foci"] = data_foci @verbose def add_sensors( self, info, trans, meg=None, eeg="original", fnirs=True, ecog=True, seeg=True, dbs=True, max_dist=0.004, *, sensor_colors=None, verbose=None, ): """Add mesh objects to represent sensor positions. Parameters ---------- %(info_not_none)s %(trans_not_none)s %(meg)s %(eeg)s %(fnirs)s %(ecog)s %(seeg)s %(dbs)s %(max_dist_ieeg)s %(sensor_colors)s .. versionadded:: 1.6 %(verbose)s Notes ----- .. versionadded:: 0.24 """ from ...preprocessing.ieeg._projection import _project_sensors_onto_inflated _validate_type(info, Info, "info") meg, eeg, fnirs, warn_meg, sensor_alpha = _handle_sensor_types(meg, eeg, fnirs) picks = pick_types( info, meg=("sensors" in meg), ref_meg=("ref" in meg), eeg=(len(eeg) > 0), ecog=ecog, seeg=seeg, dbs=dbs, fnirs=(len(fnirs) > 0), ) head_mri_t = _get_trans(trans, "head", "mri", allow_none=False)[0] if self._surf in ("inflated", "flat"): for modality, check in dict(seeg=seeg, ecog=ecog).items(): if pick_types(info, **{modality: check}).size > 0: info = _project_sensors_onto_inflated( info.copy(), head_mri_t, subject=self._subject, subjects_dir=self._subjects_dir, picks=modality, max_dist=max_dist, flat=self._surf == "flat", ) del trans # get transforms to "mri" window to_cf_t = _get_transforms_to_coord_frame(info, head_mri_t, coord_frame="mri") if pick_types(info, eeg=True, exclude=()).size > 0 and "projected" in eeg: head_surf = _get_head_surface("seghead", self._subject, self._subjects_dir) else: head_surf = None # Do the main plotting for _ in self._iter_views("vol"): if picks.size > 0: sensors_actors = _plot_sensors_3d( self._renderer, info, to_cf_t, picks, meg, eeg, fnirs, warn_meg, head_surf, self._units, sensor_alpha=sensor_alpha, sensor_colors=sensor_colors, ) # sensors_actors can still be None for item, actors in (sensors_actors or {}).items(): for actor in actors: self._add_actor(item, actor) if "helmet" in meg and pick_types(info, meg=True).size > 0: actor, _, _ = _plot_helmet( self._renderer, info, to_cf_t, head_mri_t, "mri", alpha=sensor_alpha["meg_helmet"], scale=1 if self._units == "m" else 1e3, ) self._add_actor("helmet", actor) self._renderer._update() def remove_sensors(self, kind=None): """Remove sensors from the rendered scene. Parameters ---------- kind : str | list | None If None, removes all sensor-related data including the helmet. Can be "meg", "eeg", "fnirs", "ecog", "seeg", "dbs" or "helmet" to remove that item. """ all_kinds = ("meg", "eeg", "fnirs", "ecog", "seeg", "dbs", "helmet") if kind is None: for item in all_kinds: self._remove(item, render=False) else: if isinstance(kind, str): kind = [kind] for this_kind in kind: _check_option("kind", this_kind, all_kinds) self._remove(this_kind, render=False) self._renderer._update() def add_text( self, x, y, text, name=None, color=None, opacity=1.0, row=0, col=0, font_size=None, justification=None, ): """Add a text to the visualization. Parameters ---------- x : float X coordinate. y : float Y coordinate. text : str Text to add. name : str Name of the text (text label can be updated using update_text()). color : tuple Color of the text. Default is the foreground color set during initialization (default is black or white depending on the background color). opacity : float Opacity of the text (default 1.0). row : int | None Row index of which brain to use. Default is the top row. col : int | None Column index of which brain to use. Default is the left-most column. font_size : float | None The font size to use. justification : str | None The text justification. """ _validate_type(name, (str, None), "name") name = text if name is None else name if "text" in self._actors and name in self._actors["text"]: raise ValueError(f"Text with the name {name} already exists") if color is None: color = self._fg_color for ri, ci, _ in self._iter_views("vol"): if (row is None or row == ri) and (col is None or col == ci): actor = self._renderer.text2d( x_window=x, y_window=y, text=text, color=color, size=font_size, justification=justification, ) if "text" not in self._actors: self._actors["text"] = dict() self._actors["text"][name] = actor def remove_text(self, name=None): """Remove text from the rendered scene. Parameters ---------- name : str | None Remove specific text by name. If None, all text will be removed. """ _validate_type(name, (str, None), "name") if name is None: for actor in self._actors["text"].values(): self._renderer.plotter.remove_actor(actor, render=False) self._actors.pop("text") else: names = [None] if "text" in self._actors: names += list(self._actors["text"].keys()) _check_option("name", name, names) self._renderer.plotter.remove_actor( self._actors["text"][name], render=False ) self._actors["text"].pop(name) self._renderer._update() def _configure_label_time_course(self): from ...label import read_labels_from_annot if not self.show_traces: return if self.mpl_canvas is None: self._configure_mplcanvas() else: self.clear_glyphs() self.traces_mode = "label" self.add_annotation(self.annot, color="w", alpha=0.75) # now plot the time line self.plot_time_line(update=False) self.mpl_canvas.update_plot() for hemi in self._hemis: labels = read_labels_from_annot( subject=self._subject, parc=self.annot, hemi=hemi, subjects_dir=self._subjects_dir, ) self._vertex_to_label_id[hemi] = np.full(self.geo[hemi].coords.shape[0], -1) self._annotation_labels[hemi] = labels for idx, label in enumerate(labels): self._vertex_to_label_id[hemi][label.vertices] = idx @fill_doc def add_annotation( self, annot, borders=True, alpha=1, hemi=None, remove_existing=True, color=None ): """Add an annotation file. Parameters ---------- annot : str | tuple Either path to annotation file or annotation name. Alternatively, the annotation can be specified as a ``(labels, ctab)`` tuple per hemisphere, i.e. ``annot=(labels, ctab)`` for a single hemisphere or ``annot=((lh_labels, lh_ctab), (rh_labels, rh_ctab))`` for both hemispheres. ``labels`` and ``ctab`` should be arrays as returned by :func:`nibabel.freesurfer.io.read_annot`. borders : bool | int Show only label borders. If int, specify the number of steps (away from the true border) along the cortical mesh to include as part of the border definition. %(alpha)s Default is 1. hemi : str | None If None, it is assumed to belong to the hemisphere being shown. If two hemispheres are being shown, data must exist for both hemispheres. remove_existing : bool If True (default), remove old annotations. color : matplotlib-style color code If used, show all annotations in the same (specified) color. Probably useful only when showing annotation borders. """ from ...label import _read_annot hemis = self._check_hemis(hemi) # Figure out where the data is coming from if _path_like(annot): if os.path.isfile(annot): filepath = _check_fname(annot, overwrite="read") file_hemi, annot = filepath.name.split(".", 1) if len(hemis) > 1: if file_hemi == "lh": filepaths = [filepath, filepath.parent / ("rh." + annot)] elif file_hemi == "rh": filepaths = [filepath.parent / ("lh." + annot), filepath] else: raise RuntimeError( "To add both hemispheres simultaneously, filename must " 'begin with "lh." or "rh."' ) else: filepaths = [filepath] else: filepaths = [] for hemi in hemis: filepath = op.join( self._subjects_dir, self._subject, "label", ".".join([hemi, annot, "annot"]), ) if not os.path.exists(filepath): raise ValueError(f"Annotation file {filepath} does not exist") filepaths += [filepath] annots = [] for hemi, filepath in zip(hemis, filepaths): # Read in the data labels, cmap, _ = _read_annot(filepath) annots.append((labels, cmap)) else: annots = [annot] if len(hemis) == 1 else annot annot = "annotation" for hemi, (labels, cmap) in zip(hemis, annots): # Maybe zero-out the non-border vertices self._to_borders(labels, hemi, borders) # Handle null labels properly cmap[:, 3] = 255 bgcolor = np.round(np.array(self._brain_color) * 255).astype(int) bgcolor[-1] = 0 cmap[cmap[:, 4] < 0, 4] += 2**24 # wrap to positive cmap[cmap[:, 4] <= 0, :4] = bgcolor if np.any(labels == 0) and not np.any(cmap[:, -1] <= 0): cmap = np.vstack((cmap, np.concatenate([bgcolor, [0]]))) # Set label ids sensibly order = np.argsort(cmap[:, -1]) cmap = cmap[order] ids = np.searchsorted(cmap[:, -1], labels) cmap = cmap[:, :4] # Set the alpha level alpha_vec = cmap[:, 3] alpha_vec[alpha_vec > 0] = alpha * 255 # Override the cmap when a single color is used if color is not None: rgb = np.round(np.multiply(_to_rgb(color), 255)) cmap[:, :3] = rgb.astype(cmap.dtype) ctable = cmap.astype(np.float64) for _ in self._iter_views(hemi): mesh = self._layered_meshes[hemi] mesh.add_overlay( scalars=ids, colormap=ctable, rng=[np.min(ids), np.max(ids)], opacity=alpha, name=annot, ) self._annots[hemi].append(annot) if not self.time_viewer or self.traces_mode == "vertex": self._renderer._set_colormap_range( mesh._actor, cmap.astype(np.uint8), None ) self._renderer._update() def close(self): """Close all figures and cleanup data structure.""" self._closed = True self._renderer.close() def show(self): """Display the window.""" from ..backends._utils import _qt_app_exec self._renderer.show() if self._block: _qt_app_exec(self._renderer.figure.store["app"]) @fill_doc def get_view(self, row=0, col=0, *, align=True): """Get the camera orientation for a given subplot display. Parameters ---------- row : int The row to use, default is the first one. col : int The column to check, the default is the first one. %(align_view)s Returns ------- %(roll)s %(distance)s %(azimuth)s %(elevation)s %(focalpoint)s """ row = _ensure_int(row, "row") col = _ensure_int(col, "col") rigid = self._rigid if align else None for h in self._hemis: for ri, ci, _ in self._iter_views(h): if (row == ri) and (col == ci): return self._renderer.get_camera(rigid=rigid) return (None,) * 5 @verbose def show_view( self, view=None, roll=None, distance=None, *, row=None, col=None, hemi=None, align=True, azimuth=None, elevation=None, focalpoint=None, update=True, verbose=None, ): """Orient camera to display view. Parameters ---------- %(view)s %(roll)s %(distance)s row : int | None The row to set. Default all rows. col : int | None The column to set. Default all columns. hemi : str | None Which hemi to use for view lookup (when in "both" mode). %(align_view)s %(azimuth)s %(elevation)s %(focalpoint)s %(brain_update)s .. versionadded:: 1.6 %(verbose)s Notes ----- The builtin string views are the following perspectives, based on the :term:`RAS` convention. If not otherwise noted, the view will have the top of the brain (superior, +Z) in 3D space shown upward in the 2D perspective: ``'lateral'`` From the left or right side such that the lateral (outside) surface of the given hemisphere is visible. ``'medial'`` From the left or right side such that the medial (inside) surface of the given hemisphere is visible (at least when in split or single-hemi mode). ``'rostral'`` From the front. ``'caudal'`` From the rear. ``'dorsal'`` From above, with the front of the brain pointing up. ``'ventral'`` From below, with the front of the brain pointing up. ``'frontal'`` From the front and slightly lateral, with the brain slightly tilted forward (yielding a view from slightly above). ``'parietal'`` From the rear and slightly lateral, with the brain slightly tilted backward (yielding a view from slightly above). ``'axial'`` From above with the brain pointing up (same as ``'dorsal'``). ``'sagittal'`` From the right side. ``'coronal'`` From the rear. Three letter abbreviations (e.g., ``'lat'``) of all of the above are also supported. """ _validate_type(row, ("int-like", None), "row") _validate_type(col, ("int-like", None), "col") hemi = self._hemi if hemi is None else hemi if hemi == "split": if ( self._view_layout == "vertical" and col == 1 or self._view_layout == "horizontal" and row == 1 ): hemi = "rh" else: hemi = "lh" _validate_type(view, (str, None), "view") view_params = dict( azimuth=azimuth, elevation=elevation, roll=roll, distance=distance, focalpoint=focalpoint, ) if view is not None: # view_params take precedence view_params = { param: val for param, val in view_params.items() if val is not None } # no overwriting with None view_params = dict(views_dicts[hemi].get(view), **view_params) for h in self._hemis: for ri, ci, _ in self._iter_views(h): if (row is None or row == ri) and (col is None or col == ci): self._set_camera(**view_params, align=align) if update: self._renderer._update() def _set_camera( self, *, distance=None, focalpoint=None, update=False, align=True, verbose=None, **kwargs, ): # Wrap to self._renderer.set_camera safely, always passing self._rigid # and using better no-op-like defaults return self._renderer.set_camera( distance=distance, focalpoint=focalpoint, update=update, rigid=self._rigid if align else None, **kwargs, ) def reset_view(self): """Reset the camera.""" for h in self._hemis: for _, _, v in self._iter_views(h): self._set_camera(**views_dicts[h][v]) self._renderer._update() def save_image(self, filename=None, mode="rgb"): """Save view from all panels to disk. Parameters ---------- filename : path-like Path to new image file. mode : str Either ``'rgb'`` or ``'rgba'`` for values to return. """ if filename is None: filename = _generate_default_filename(".png") _save_ndarray_img(filename, self.screenshot(mode=mode, time_viewer=True)) @fill_doc def screenshot(self, mode="rgb", time_viewer=False): """Generate a screenshot of current view. Parameters ---------- mode : str Either ``'rgb'`` or ``'rgba'`` for values to return. %(time_viewer_brain_screenshot)s Returns ------- screenshot : array Image pixel values. """ n_channels = 3 if mode == "rgb" else 4 img = self._renderer.screenshot(mode) logger.debug(f"Got screenshot of size {img.shape}") if ( time_viewer and self.time_viewer and self.show_traces and not self.separate_canvas ): from matplotlib.image import imread canvas = self.mpl_canvas.fig.canvas canvas.draw_idle() fig = self.mpl_canvas.fig with BytesIO() as output: # Need to pass dpi here so it uses the physical (HiDPI) DPI # rather than logical DPI when saving in most cases. # But when matplotlib uses HiDPI and VTK doesn't # (e.g., macOS w/Qt 5.14+ and VTK9) then things won't work, # so let's just calculate the DPI we need to get # the correct size output based on the widths being equal size_in = fig.get_size_inches() dpi = fig.get_dpi() want_size = tuple(x * dpi for x in size_in) n_pix = want_size[0] * want_size[1] logger.debug( f"Saving figure of size {size_in} @ {dpi} DPI " f"({want_size} = {n_pix} pixels)" ) # Sometimes there can be off-by-one errors here (e.g., # if in mpl int() rather than int(round()) is used to # compute the number of pixels) so rather than use "raw" # format and try to reshape ourselves, just write to PNG # and read it, which has the dimensions encoded for us. fig.savefig( output, dpi=dpi, format="png", facecolor=self._bg_color, edgecolor="none", ) output.seek(0) trace_img = imread(output, format="png")[:, :, :n_channels] trace_img = np.clip(np.round(trace_img * 255), 0, 255).astype(np.uint8) bgcolor = np.array(self._brain_color[:n_channels]) / 255 img = concatenate_images( [img, trace_img], bgcolor=bgcolor, n_channels=n_channels ) return img @fill_doc def update_lut(self, fmin=None, fmid=None, fmax=None, alpha=None): """Update the range of the color map. Parameters ---------- %(fmin_fmid_fmax)s %(alpha)s """ publish( self, ColormapRange( kind="distributed_source_power", fmin=fmin, fmid=fmid, fmax=fmax, alpha=alpha, ), ) @fill_doc def _update_colormap_range(self, fmin=None, fmid=None, fmax=None, alpha=None): """Update the range of the color map. Parameters ---------- %(fmin_fmid_fmax)s %(alpha)s """ args = f"{fmin}, {fmid}, {fmax}, {alpha}" logger.debug(f"Updating LUT with {args}") center = self._data["center"] colormap = self._data["colormap"] transparent = self._data["transparent"] lims = {key: self._data[key] for key in ("fmin", "fmid", "fmax")} _update_monotonic(lims, fmin=fmin, fmid=fmid, fmax=fmax) assert all(val is not None for val in lims.values()) self._data.update(lims) self._data["ctable"] = np.round( calculate_lut( colormap, alpha=1.0, center=center, transparent=transparent, **lims ) * 255 ).astype(np.uint8) # update our values rng = self._cmap_range ctable = self._data["ctable"] for hemi in ["lh", "rh", "vol"]: hemi_data = self._data.get(hemi) if hemi_data is not None: if hemi in self._layered_meshes: mesh = self._layered_meshes[hemi] mesh.update_overlay( name="data", colormap=self._data["ctable"], opacity=alpha, rng=rng, ) self._renderer._set_colormap_range( mesh._actor, ctable, self._scalar_bar, rng, self._brain_color ) grid_volume_pos = hemi_data.get("grid_volume_pos") grid_volume_neg = hemi_data.get("grid_volume_neg") for grid_volume in (grid_volume_pos, grid_volume_neg): if grid_volume is not None: self._renderer._set_volume_range( grid_volume, ctable, hemi_data["alpha"], self._scalar_bar, rng, ) glyph_actor = hemi_data.get("glyph_actor") if glyph_actor is not None: for glyph_actor_ in glyph_actor: self._renderer._set_colormap_range( glyph_actor_, ctable, self._scalar_bar, rng ) self._renderer._update() def set_data_smoothing(self, n_steps): """Set the number of smoothing steps. Parameters ---------- n_steps : int Number of smoothing steps. """ from ...morph import _hemi_morph for hemi in ["lh", "rh"]: hemi_data = self._data.get(hemi) if hemi_data is not None: if len(hemi_data["array"]) >= self.geo[hemi].x.shape[0]: continue vertices = hemi_data["vertices"] if vertices is None: raise ValueError( f"len(data) < nvtx ({len(hemi_data)} < " f"{self.geo[hemi].x.shape[0]}): the vertices " "parameter must not be None" ) morph_n_steps = "nearest" if n_steps == -1 else n_steps with use_log_level(False): smooth_mat = _hemi_morph( self.geo[hemi].orig_faces, np.arange(len(self.geo[hemi].coords)), vertices, morph_n_steps, maps=None, warn=False, ) self._data[hemi]["smooth_mat"] = smooth_mat self._update_current_time_idx(self._data["time_idx"]) self._data["smoothing_steps"] = n_steps @property def _n_times(self): return len(self._times) if self._times is not None else None @property def time_interpolation(self): """The interpolation mode.""" return self._time_interpolation @fill_doc def set_time_interpolation(self, interpolation): """Set the interpolation mode. Parameters ---------- %(interpolation_brain_time)s """ self._time_interpolation = _check_option( "interpolation", interpolation, ("linear", "nearest", "zero", "slinear", "quadratic", "cubic"), ) self._time_interp_funcs = dict() self._time_interp_inv = None if self._times is not None: idx = np.arange(self._n_times) for hemi in ["lh", "rh", "vol"]: hemi_data = self._data.get(hemi) if hemi_data is not None: array = hemi_data["array"] self._time_interp_funcs[hemi] = _safe_interp1d( idx, array, self._time_interpolation, axis=-1, assume_sorted=True, ) self._time_interp_inv = _safe_interp1d(idx, self._times) def _update_current_time_idx(self, time_idx): """Update all widgets in the figure to reflect a new time point. Parameters ---------- time_idx : int | float The time index to use. Can be a float to use interpolation between indices. """ self._current_act_data = dict() time_actor = self._data.get("time_actor", None) time_label = self._data.get("time_label", None) for hemi in ["lh", "rh", "vol"]: hemi_data = self._data.get(hemi) if hemi_data is not None: array = hemi_data["array"] # interpolate in time vectors = None if array.ndim == 1: act_data = array self._current_time = 0 else: act_data = self._time_interp_funcs[hemi](time_idx) self._current_time = self._time_interp_inv(time_idx) if array.ndim == 3: vectors = act_data act_data = np.linalg.norm(act_data, axis=1) self._current_time = self._time_interp_inv(time_idx) self._current_act_data[hemi] = act_data if time_actor is not None and time_label is not None: time_actor.SetInput(time_label(self._current_time)) # update the volume interpolation grid = hemi_data.get("grid") if grid is not None: vertices = self._data["vol"]["vertices"] values = self._current_act_data["vol"] rng = self._cmap_range fill = 0 if self._data["center"] is not None else rng[0] grid.cell_data["values"].fill(fill) # XXX for sided data, we probably actually need two # volumes as composite/MIP needs to look at two # extremes... for now just use abs. Eventually we can add # two volumes if we want. grid.cell_data["values"][vertices] = values # interpolate in space smooth_mat = hemi_data.get("smooth_mat") if smooth_mat is not None: act_data = smooth_mat.dot(act_data) # update the mesh scalar values if hemi in self._layered_meshes: mesh = self._layered_meshes[hemi] if "data" in mesh._overlays: mesh.update_overlay(name="data", scalars=act_data) else: mesh.add_overlay( scalars=act_data, colormap=self._data["ctable"], rng=self._cmap_range, opacity=None, name="data", ) # update the glyphs if vectors is not None: self._update_glyphs(hemi, vectors) self._data["time_idx"] = time_idx self._renderer._update() def set_time_point(self, time_idx): """Set the time point to display (can be a float to interpolate). Parameters ---------- time_idx : int | float The time index to use. Can be a float to use interpolation between indices. """ if self._times is None: raise ValueError("Cannot set time when brain has no defined times.") elif 0 <= time_idx <= len(self._times): publish(self, TimeChange(time=self._time_interp_inv(time_idx))) else: raise ValueError( f"Requested time point ({time_idx}) is outside the range of " f"available time points (0-{len(self._times)})." ) def set_time(self, time): """Set the time to display (in seconds). Parameters ---------- time : float The time to show, in seconds. """ if self._times is None: raise ValueError("Cannot set time when brain has no defined times.") elif min(self._times) <= time <= max(self._times): publish(self, TimeChange(time=time)) else: raise ValueError( f"Requested time ({time} s) is outside the range of " f"available times ({min(self._times)}-{max(self._times)} s)." ) def _update_glyphs(self, hemi, vectors): hemi_data = self._data.get(hemi) assert hemi_data is not None vertices = hemi_data["vertices"] vector_alpha = self._data["vector_alpha"] scale_factor = self._data["scale_factor"] vertices = slice(None) if vertices is None else vertices x, y, z = np.array(self.geo[hemi].coords)[vertices].T if hemi_data["glyph_actor"] is None: add = True hemi_data["glyph_actor"] = list() else: add = False count = 0 for _ in self._iter_views(hemi): if hemi_data["glyph_dataset"] is None: glyph_mapper, glyph_dataset = self._renderer.quiver3d( x, y, z, vectors[:, 0], vectors[:, 1], vectors[:, 2], color=None, mode="2darrow", scale_mode="vector", scale=scale_factor, opacity=vector_alpha, ) hemi_data["glyph_dataset"] = glyph_dataset hemi_data["glyph_mapper"] = glyph_mapper else: glyph_dataset = hemi_data["glyph_dataset"] glyph_dataset.point_data["vec"] = vectors glyph_mapper = hemi_data["glyph_mapper"] if add: glyph_actor = self._renderer._actor(glyph_mapper) prop = glyph_actor.GetProperty() prop.SetLineWidth(2.0) prop.SetOpacity(vector_alpha) self._renderer.plotter.add_actor(glyph_actor, render=False) hemi_data["glyph_actor"].append(glyph_actor) else: glyph_actor = hemi_data["glyph_actor"][count] count += 1 self._renderer._set_colormap_range( actor=glyph_actor, ctable=self._data["ctable"], scalar_bar=None, rng=self._cmap_range, ) @property def _cmap_range(self): dt_max = self._data["fmax"] if self._data["center"] is None: dt_min = self._data["fmin"] else: dt_min = -1 * dt_max rng = [dt_min, dt_max] return rng def _update_fscale(self, fscale): """Scale the colorbar points.""" fmin = self._data["fmin"] * fscale fmid = self._data["fmid"] * fscale fmax = self._data["fmax"] * fscale self.update_lut(fmin=fmin, fmid=fmid, fmax=fmax) def _update_auto_scaling(self, restore=False): user_clim = self._data["clim"] if user_clim is not None and "lims" in user_clim: allow_pos_lims = False else: allow_pos_lims = True if user_clim is not None and restore: clim = user_clim else: clim = "auto" colormap = self._data["colormap"] transparent = self._data["transparent"] mapdata = _process_clim( clim, colormap, transparent, np.concatenate(list(self._current_act_data.values())), allow_pos_lims, ) diverging = "pos_lims" in mapdata["clim"] colormap = mapdata["colormap"] scale_pts = mapdata["clim"]["pos_lims" if diverging else "lims"] transparent = mapdata["transparent"] del mapdata fmin, fmid, fmax = scale_pts center = 0.0 if diverging else None self._data["center"] = center self._data["colormap"] = colormap self._data["transparent"] = transparent self.update_lut(fmin=fmin, fmid=fmid, fmax=fmax) def _to_time_index(self, value): """Return the interpolated time index of the given time value.""" time = self._data["time"] value = np.interp(value, time, np.arange(len(time))) return value @property def data(self): """Data used by time viewer and color bar widgets.""" return self._data @property def labels(self): return self._labels @property def views(self): return self._views @property def hemis(self): return self._hemis def _save_movie( self, filename, time_dilation=4.0, tmin=None, tmax=None, framerate=24, interpolation=None, codec=None, bitrate=None, callback=None, time_viewer=False, **kwargs, ): import imageio with self._renderer._disabled_interaction(): images = self._make_movie_frames( time_dilation, tmin, tmax, framerate, interpolation, callback, time_viewer, ) # find imageio FFMPEG parameters if "fps" not in kwargs: kwargs["fps"] = framerate if codec is not None: kwargs["codec"] = codec if bitrate is not None: kwargs["bitrate"] = bitrate # when using GIF we need to convert FPS to duration in milliseconds for Pillow if str(filename).endswith(".gif"): kwargs["duration"] = 1000 * len(images) / kwargs.pop("fps") imageio.mimwrite(filename, images, **kwargs) def _save_movie_tv( self, filename, time_dilation=4.0, tmin=None, tmax=None, framerate=24, interpolation=None, codec=None, bitrate=None, callback=None, time_viewer=False, **kwargs, ): def frame_callback(frame, n_frames): if frame == n_frames: # On the ImageIO step self.status_msg.set_value(f"Saving with ImageIO: {filename}") self.status_msg.show() self.status_progress.hide() self._renderer._status_bar_update() else: self.status_msg.set_value( f"Rendering images (frame {frame + 1} / {n_frames}) ..." ) self.status_msg.show() self.status_progress.show() self.status_progress.set_range([0, n_frames - 1]) self.status_progress.set_value(frame) self.status_progress.update() self.status_msg.update() self._renderer._status_bar_update() # set cursor to busy default_cursor = self._renderer._window_get_cursor() self._renderer._window_set_cursor( self._renderer._window_new_cursor("WaitCursor") ) try: self._save_movie( filename, time_dilation, tmin, tmax, framerate, interpolation, codec, bitrate, frame_callback, time_viewer, **kwargs, ) except (Exception, KeyboardInterrupt): warn("Movie saving aborted:\n" + traceback.format_exc()) finally: self._renderer._window_set_cursor(default_cursor) @fill_doc def save_movie( self, filename=None, time_dilation=4.0, tmin=None, tmax=None, framerate=24, interpolation=None, codec=None, bitrate=None, callback=None, time_viewer=False, **kwargs, ): """Save a movie (for data with a time axis). The movie is created through the :mod:`imageio` module. The format is determined by the extension, and additional options can be specified through keyword arguments that depend on the format, see :doc:`imageio's format page `. .. Warning:: This method assumes that time is specified in seconds when adding data. If time is specified in milliseconds this will result in movies 1000 times longer than expected. Parameters ---------- filename : str Path at which to save the movie. The extension determines the format (e.g., ``'*.mov'``, ``'*.gif'``, ...; see the :mod:`imageio` documentation for available formats). time_dilation : float Factor by which to stretch time (default 4). For example, an epoch from -100 to 600 ms lasts 700 ms. With ``time_dilation=4`` this would result in a 2.8 s long movie. tmin : float First time point to include (default: all data). tmax : float Last time point to include (default: all data). framerate : float Framerate of the movie (frames per second, default 24). %(interpolation_brain_time)s If None, it uses the current ``brain.interpolation``, which defaults to ``'nearest'``. Defaults to None. codec : str | None The codec to use. bitrate : float | None The bitrate to use. callback : callable | None A function to call on each iteration. Useful for status message updates. It will be passed keyword arguments ``frame`` and ``n_frames``. %(time_viewer_brain_screenshot)s **kwargs : dict Specify additional options for :mod:`imageio`. """ if filename is None: filename = _generate_default_filename(".mp4") func = self._save_movie_tv if self.time_viewer else self._save_movie func( filename, time_dilation, tmin, tmax, framerate, interpolation, codec, bitrate, callback, time_viewer, **kwargs, ) def _make_movie_frames( self, time_dilation, tmin, tmax, framerate, interpolation, callback, time_viewer ): from math import floor # find tmin if tmin is None: tmin = self._times[0] elif tmin < self._times[0]: raise ValueError( f"tmin={repr(tmin)} is smaller than the first time point " f"({repr(self._times[0])})" ) # find indexes at which to create frames if tmax is None: tmax = self._times[-1] elif tmax > self._times[-1]: raise ValueError( f"tmax={repr(tmax)} is greater than the latest time point " f"({repr(self._times[-1])})" ) n_frames = floor((tmax - tmin) * time_dilation * framerate) times = np.arange(n_frames, dtype=float) times /= framerate * time_dilation times += tmin time_idx = np.interp(times, self._times, np.arange(self._n_times)) n_times = len(time_idx) if n_times == 0: raise ValueError("No time points selected") logger.debug(f"Save movie for time points/samples\n{times}\n{time_idx}") # Sometimes the first screenshot is rendered with a different # resolution on OS X self.screenshot(time_viewer=time_viewer) old_mode = self.time_interpolation if interpolation is not None: self.set_time_interpolation(interpolation) try: images = [ self.screenshot(time_viewer=time_viewer) for _ in self._iter_time(time_idx, callback) ] finally: self.set_time_interpolation(old_mode) if callback is not None: callback(frame=len(time_idx), n_frames=len(time_idx)) return images def _iter_time(self, time_idx, callback): """Iterate through time points, then reset to current time. Parameters ---------- time_idx : array_like Time point indexes through which to iterate. callback : callable | None Callback to call before yielding each frame. Yields ------ idx : int | float Current index. Notes ----- Used by movie and image sequence saving functions. """ current_time_idx = self._data["time_idx"] for ii, idx in enumerate(time_idx): self.set_time_point(idx) if callback is not None: callback(frame=ii, n_frames=len(time_idx)) yield idx # Restore original time index self.set_time_point(current_time_idx) def _check_stc(self, hemi, array, vertices): from ...source_estimate import ( _BaseMixedSourceEstimate, _BaseSourceEstimate, _BaseSurfaceSourceEstimate, _BaseVolSourceEstimate, ) if isinstance(array, _BaseSourceEstimate): stc = array stc_surf = stc_vol = None if isinstance(stc, _BaseSurfaceSourceEstimate): stc_surf = stc elif isinstance(stc, _BaseMixedSourceEstimate): stc_surf = stc.surface() if hemi != "vol" else None stc_vol = stc.volume() if hemi == "vol" else None elif isinstance(stc, _BaseVolSourceEstimate): stc_vol = stc if hemi == "vol" else None else: raise TypeError("stc not supported") if stc_surf is None and stc_vol is None: raise ValueError("No data to be added") if stc_surf is not None: array = getattr(stc_surf, hemi + "_data") vertices = stc_surf.vertices[0 if hemi == "lh" else 1] if stc_vol is not None: array = stc_vol.data vertices = np.concatenate(stc_vol.vertices) else: stc = None return stc, array, vertices def _check_hemi(self, hemi, extras=()): """Check for safe single-hemi input, returns str.""" _validate_type(hemi, (None, str), "hemi") if hemi is None: if self._hemi not in ["lh", "rh"]: raise ValueError( "hemi must not be None when both hemispheres are displayed" ) hemi = self._hemi _check_option("hemi", hemi, ("lh", "rh") + tuple(extras)) return hemi def _check_hemis(self, hemi): """Check for safe dual or single-hemi input, returns list.""" if hemi is None: if self._hemi not in ["lh", "rh"]: hemi = ["lh", "rh"] else: hemi = [self._hemi] elif hemi not in ["lh", "rh"]: extra = " or None" if self._hemi in ["lh", "rh"] else "" raise ValueError('hemi must be either "lh" or "rh"' + extra) else: hemi = [hemi] return hemi def _to_borders(self, label, hemi, borders, restrict_idx=None): """Convert a label/parc to borders.""" if not isinstance(borders, (bool, int)) or borders < 0: raise ValueError("borders must be a bool or positive integer") if borders: n_vertices = label.size edges = mesh_edges(self.geo[hemi].orig_faces) edges = edges.tocoo() border_edges = label[edges.row] != label[edges.col] show = np.zeros(n_vertices, dtype=np.int64) keep_idx = np.unique(edges.row[border_edges]) if isinstance(borders, int): for _ in range(borders): keep_idx = np.isin(self.geo[hemi].orig_faces.ravel(), keep_idx) keep_idx.shape = self.geo[hemi].orig_faces.shape keep_idx = self.geo[hemi].orig_faces[np.any(keep_idx, axis=1)] keep_idx = np.unique(keep_idx) if restrict_idx is not None: keep_idx = keep_idx[np.isin(keep_idx, restrict_idx)] show[keep_idx] = 1 label *= show def get_picked_points(self): """Return the vertices of the picked points. Returns ------- points : list of int | None The vertices picked by the time viewer. """ if hasattr(self, "time_viewer"): return self.picked_points def __hash__(self): """Hash the object.""" return self._hash def _safe_interp1d(x, y, kind="linear", axis=-1, assume_sorted=False): """Work around interp1d not liking singleton dimensions.""" if y.shape[axis] == 1: def func(x): return np.take(y, np.zeros(np.asarray(x).shape, int), axis=axis) return func else: return interp1d(x, y, kind, axis=axis, assume_sorted=assume_sorted) def _update_limits(fmin, fmid, fmax, center, array): if center is None: if fmin is None: fmin = array.min() if array.size > 0 else 0 if fmax is None: fmax = array.max() if array.size > 0 else 1 else: if fmin is None: fmin = 0 if fmax is None: fmax = np.abs(center - array).max() if array.size > 0 else 1 if fmid is None: fmid = (fmin + fmax) / 2.0 if fmin >= fmid: raise RuntimeError(f"min must be < mid, got {fmin:0.4g} >= {fmid:0.4g}") if fmid >= fmax: raise RuntimeError(f"mid must be < max, got {fmid:0.4g} >= {fmax:0.4g}") return fmin, fmid, fmax def _update_monotonic(lims, fmin, fmid, fmax): if fmin is not None: lims["fmin"] = fmin if lims["fmax"] < fmin: logger.debug(f' Bumping fmax = {lims["fmax"]} to {fmin}') lims["fmax"] = fmin if lims["fmid"] < fmin: logger.debug(f' Bumping fmid = {lims["fmid"]} to {fmin}') lims["fmid"] = fmin assert lims["fmin"] <= lims["fmid"] <= lims["fmax"] if fmid is not None: lims["fmid"] = fmid if lims["fmin"] > fmid: logger.debug(f' Bumping fmin = {lims["fmin"]} to {fmid}') lims["fmin"] = fmid if lims["fmax"] < fmid: logger.debug(f' Bumping fmax = {lims["fmax"]} to {fmid}') lims["fmax"] = fmid assert lims["fmin"] <= lims["fmid"] <= lims["fmax"] if fmax is not None: lims["fmax"] = fmax if lims["fmin"] > fmax: logger.debug(f' Bumping fmin = {lims["fmin"]} to {fmax}') lims["fmin"] = fmax if lims["fmid"] > fmax: logger.debug(f' Bumping fmid = {lims["fmid"]} to {fmax}') lims["fmid"] = fmax assert lims["fmin"] <= lims["fmid"] <= lims["fmax"] def _get_range(brain): """Get the data limits. Since they may be very small (1E-10 and such), we apply a scaling factor such that the data range lies somewhere between 0.01 and 100. This makes for more usable sliders. When setting a value on the slider, the value is multiplied by the scaling factor and when getting a value, this value should be divided by the scaling factor. """ fmax = abs(brain._data["fmax"]) if 1e-02 <= fmax <= 1e02: fscale_power = 0 else: fscale_power = int(np.log10(max(fmax, np.finfo("float32").smallest_normal))) if fscale_power < 0: fscale_power -= 1 fscale = 10**-fscale_power return fmax, fscale, fscale_power class _FakeIren: def EnterEvent(self): pass def MouseMoveEvent(self): pass def LeaveEvent(self): pass def SetEventInformation(self, *args, **kwargs): pass def CharEvent(self): pass def KeyPressEvent(self, *args, **kwargs): pass def KeyReleaseEvent(self, *args, **kwargs): pass