210 lines
		
	
	
		
			6.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			210 lines
		
	
	
		
			6.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Authors: The MNE-Python contributors.
 | 
						|
# License: BSD-3-Clause
 | 
						|
# Copyright the MNE-Python contributors.
 | 
						|
 | 
						|
import numpy as np
 | 
						|
from scipy.ndimage import gaussian_filter
 | 
						|
 | 
						|
from ..._fiff.constants import FIFF
 | 
						|
from ...utils import _validate_type, fill_doc, logger
 | 
						|
from ..utils import plt_show
 | 
						|
 | 
						|
 | 
						|
@fill_doc
def plot_gaze(
    epochs,
    *,
    calibration=None,
    width=None,
    height=None,
    sigma=25,
    cmap=None,
    alpha=1.0,
    vlim=(None, None),
    axes=None,
    show=True,
):
    """Plot a heatmap of eyetracking gaze data.

    Parameters
    ----------
    epochs : instance of Epochs
        The :class:`~mne.Epochs` object containing eyegaze channels.
    calibration : instance of Calibration | None
        An instance of Calibration with information about the screen size, distance,
        and resolution. If ``None``, you must provide a width and height.
    width : int | None
        The width dimension of the plot canvas, only valid if eyegaze data are in
        pixels. For example, if the participant screen resolution was 1920x1080, then
        the width should be 1920. Must be ``None`` if ``calibration`` is provided.
    height : int | None
        The height dimension of the plot canvas, only valid if eyegaze data are in
        pixels. For example, if the participant screen resolution was 1920x1080, then
        the height should be 1080. Must be ``None`` if ``calibration`` is provided.
    sigma : float | None
        The amount of Gaussian smoothing applied to the heatmap data (standard
        deviation in pixels). If ``None``, no smoothing is applied. Default is 25.
    %(cmap)s
    alpha : float
        The opacity of the heatmap (default is 1).
    %(vlim_plot_topomap)s
    %(axes_plot_topomap)s
    %(show)s

    Returns
    -------
    fig : instance of Figure
        The resulting figure object for the heatmap plot.

    Raises
    ------
    ValueError
        If ``calibration`` is given together with ``width``/``height``, if neither
        ``calibration`` nor both ``width`` and ``height`` are given, if the gaze
        data are in radians but no ``calibration`` is given, or if the gaze data
        are neither in pixels nor in radians.

    Notes
    -----
    For binocular recordings the positions are averaged across eyes (ignoring
    NaNs) before binning. Samples outside the canvas are not counted. Heatmap
    values are dwell times in seconds: sample counts divided by
    ``epochs.info["sfreq"]``.

    .. versionadded:: 1.6
    """
    from mne import BaseEpochs
    from mne._fiff.pick import _picks_to_idx

    from ...preprocessing.eyetracking.utils import (
        _check_calibration,
        get_screen_visual_angle,
    )

    _validate_type(epochs, BaseEpochs, "epochs")
    _validate_type(alpha, "numeric", "alpha")
    _validate_type(sigma, ("numeric", None), "sigma")

    # Get the gaze data
    pos_picks = _picks_to_idx(epochs.info, "eyegaze")
    gaze_data = epochs.get_data(picks=pos_picks)
    gaze_ch_loc = np.array([epochs.info["chs"][idx]["loc"] for idx in pos_picks])
    # loc[4] is used to tell the axes apart: -1 selects x channels, 1 selects y
    x_data = gaze_data[:, np.where(gaze_ch_loc[:, 4] == -1)[0], :]
    y_data = gaze_data[:, np.where(gaze_ch_loc[:, 4] == 1)[0], :]
    unit = epochs.info["chs"][pos_picks[0]]["unit"]  # assumes all units are the same

    if x_data.shape[1] > 1:  # binocular recording. Average across eyes
        logger.info("Detected binocular recording. Averaging positions across eyes.")
        x_data = np.nanmean(x_data, axis=1)  # shape (n_epochs, n_samples)
        y_data = np.nanmean(y_data, axis=1)
    canvas = np.vstack((x_data.flatten(), y_data.flatten()))  # shape (2, n_samples)

    # Check that we have the right inputs
    if calibration is not None:
        if width is not None or height is not None:
            raise ValueError(
                "If a calibration is provided, you cannot provide a width or height"
                " to plot heatmaps. Please provide only the calibration object."
            )
        _check_calibration(calibration)
        if unit == FIFF.FIFF_UNIT_PX:
            width, height = calibration["screen_resolution"]
        elif unit == FIFF.FIFF_UNIT_RAD:
            width, height = calibration["screen_size"]
        else:
            raise ValueError(
                f"Invalid unit type: {unit}. gaze data Must be pixels or radians."
            )
    else:
        if width is None or height is None:
            raise ValueError(
                "If no calibration is provided, you must provide a width and height"
                " to plot heatmaps."
            )

    # Create 2D histogram
    # We need to set the histogram bins & bounds, and imshow extent, based on the units
    if unit == FIFF.FIFF_UNIT_PX:  # pixel on screen
        _range = [[0, height], [0, width]]
        bins_x, bins_y = width, height
        extent = [0, width, height, 0]
    elif unit == FIFF.FIFF_UNIT_RAD:  # radians of visual angle
        if calibration is None:
            raise ValueError(
                "If gaze data are in Radians, you must provide a"
                " calibration instance to plot heatmaps."
            )
        width, height = get_screen_visual_angle(calibration)
        x_range = [-width / 2, width / 2]
        y_range = [-height / 2, height / 2]
        _range = [y_range, x_range]
        # histogram2d stores the smallest y in row 0, and imshow(origin="upper")
        # draws row 0 at the extent's *top* value. The top must therefore be the
        # minimum y (as in the pixel branch) for the tick labels to match the data.
        extent = (x_range[0], x_range[1], y_range[1], y_range[0])
        bins_x, bins_y = calibration["screen_resolution"]
    else:
        # Only reachable without a calibration; without this branch the
        # histogram parameters below would be unbound.
        raise ValueError(
            f"Invalid unit type: {unit}. gaze data Must be pixels or radians."
        )

    hist, _, _ = np.histogram2d(
        canvas[1, :],
        canvas[0, :],
        bins=(bins_y, bins_x),
        range=_range,
    )
    # Convert density from samples to seconds
    hist /= epochs.info["sfreq"]
    # Smooth the heatmap
    if sigma:
        hist = gaussian_filter(hist, sigma=sigma)

    return _plot_heatmap_array(
        hist,
        width=width,
        height=height,
        cmap=cmap,
        alpha=alpha,
        vmin=vlim[0],
        vmax=vlim[1],
        extent=extent,
        axes=axes,
        show=show,
    )
 | 
						|
 | 
						|
 | 
						|
def _plot_heatmap_array(
    data,
    width,
    height,
    *,
    cmap=None,
    alpha=None,
    vmin=None,
    vmax=None,
    extent=None,
    axes=None,
    show=True,
):
    """Plot a heatmap of eyetracking gaze data from a numpy array.

    Parameters
    ----------
    data : ndarray
        2D array of values drawn with ``imshow`` (row 0 at the top).
    width : float
        Canvas width; only used to build the default ``extent``.
    height : float
        Canvas height; only used to build the default ``extent``.
    cmap : str | Colormap | None
        Colormap passed to ``imshow``.
    alpha : float | None
        Image opacity; ``None`` means fully opaque (1).
    vmin : float | None
        Lower color limit; ``None`` uses the NaN-ignoring minimum of ``data``.
    vmax : float | None
        Upper color limit; ``None`` uses the NaN-ignoring maximum of ``data``.
    extent : sequence of float | None
        ``(left, right, bottom, top)`` for ``imshow``; ``None`` uses
        ``[0, width, height, 0]``.
    axes : instance of Axes | None
        Axes to draw into; a new constrained-layout figure is created if ``None``.
    show : bool
        Whether to show the figure.

    Returns
    -------
    fig : instance of Figure
        The figure containing the heatmap and its colorbar.
    """
    import matplotlib.pyplot as plt

    # Use the caller's axes when given, otherwise make a fresh figure
    if axes is None:
        fig, ax = plt.subplots(constrained_layout=True)
    else:
        from matplotlib.axes import Axes

        _validate_type(axes, Axes, "axes")
        ax = axes
        fig = ax.get_figure()

    ax.set_title("Gaze heatmap")
    ax.set_xlabel("X position")
    ax.set_ylabel("Y position")

    # Color limits default to the data range (NaNs ignored)
    if vmin is None:
        vmin = np.nanmin(data)
    if vmax is None:
        vmax = np.nanmax(data)
    if extent is None:
        extent = [0, width, height, 0]

    image = ax.imshow(
        data,
        cmap=cmap,
        alpha=1 if alpha is None else alpha,
        vmin=vmin,
        vmax=vmax,
        extent=extent,
        origin="upper",
        aspect="equal",
    )

    fig.colorbar(image, ax=ax, shrink=0.6, label="Dwell time (seconds)")
    plt_show(show)
    return fig
 |