# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from functools import partial

import numpy as np
from scipy.spatial.distance import cdist

from ...utils import _check_option, _validate_type, fill_doc


def _check_stc(stc1, stc2):
    """Check that stcs are compatible."""
    if stc1.data.shape != stc2.data.shape:
        raise ValueError("Data in stcs must have the same size")
    # Fix: the original guard was ``np.all(stc1.times != stc2.times)``, which
    # only raised when *every* time point differed, so a partial mismatch of
    # the time axes slipped through silently. ``np.any`` raises on any
    # difference, which is the intended contract ("times must match").
    if np.any(stc1.times != stc2.times):
        raise ValueError("Times of two stcs must match.")


def source_estimate_quantification(stc1, stc2, metric="rms"):
    """Calculate STC similarities across all sources and times.

    Parameters
    ----------
    stc1 : SourceEstimate
        First source estimate for comparison.
    stc2 : SourceEstimate
        Second source estimate for comparison.
    metric : str
        Metric to calculate, ``'rms'`` or ``'cosine'``.

    Returns
    -------
    score : float | array
        Calculated metric.

    Notes
    -----
    Metric calculation has multiple options:

        * rms: Root mean square of difference between stc data matrices.
        * cosine: Normalized correlation of all elements in stc data matrices.

    .. versionadded:: 0.10.0
    """
    _check_option("metric", metric, ["rms", "cosine"])

    # This is checking that the data are having the same size meaning
    # no comparison between distributed and sparse can be done so far.
    _check_stc(stc1, stc2)
    data1, data2 = stc1.data, stc2.data

    # Calculate root mean square difference between two matrices
    if metric == "rms":
        score = np.sqrt(np.mean((data1 - data2) ** 2))
    # Calculate correlation coefficient between matrix elements
    elif metric == "cosine":
        score = 1.0 - _cosine(data1, data2)
    return score


def _uniform_stc(stc1, stc2):
    """Uniform vertices of two stcs.

    This function returns the stcs with the same vertices by
    inserting zeros in data for missing vertices.
    """
    if len(stc1.vertices) != len(stc2.vertices):
        raise ValueError(
            "Data in stcs must have the same number of vertices "
            f"components. Got {len(stc1.vertices)} != {len(stc2.vertices)}."
        )
    idx_start1 = 0
    idx_start2 = 0
    stc1 = stc1.copy()
    stc2 = stc2.copy()
    all_data1 = []
    all_data2 = []
    # Process each hemisphere/source-space component independently: build the
    # union of the two vertex sets and scatter each stc's rows into the slots
    # of the union, leaving zeros where a vertex is missing from one stc.
    for i, (vert1, vert2) in enumerate(zip(stc1.vertices, stc2.vertices)):
        vert = np.union1d(vert1, vert2)
        data1 = np.zeros([len(vert), stc1.data.shape[1]])
        data2 = np.zeros([len(vert), stc2.data.shape[1]])
        # searchsorted maps each original vertex to its row in the union
        # (valid because the union is sorted and contains both vertex sets).
        data1[np.searchsorted(vert, vert1)] = stc1.data[
            idx_start1 : idx_start1 + len(vert1)
        ]
        data2[np.searchsorted(vert, vert2)] = stc2.data[
            idx_start2 : idx_start2 + len(vert2)
        ]
        idx_start1 += len(vert1)
        idx_start2 += len(vert2)
        stc1.vertices[i] = vert
        stc2.vertices[i] = vert
        all_data1.append(data1)
        all_data2.append(data2)

    stc1._data = np.concatenate(all_data1, axis=0)
    stc2._data = np.concatenate(all_data2, axis=0)
    return stc1, stc2


def _apply(func, stc_true, stc_est, per_sample):
    """Apply metric to stcs.

    Applies a metric to each pair of columns of stc_true and stc_est if
    per_sample is True. Otherwise it applies it to stc_true and stc_est
    directly.
    """
    if per_sample:
        metric = np.empty(stc_true.data.shape[1])  # one value per time point
        for i in range(stc_true.data.shape[1]):
            # Keep 2D shape (n_vertices, 1) so func sees a matrix, not a 1D
            # vector, matching the per_sample=False code path.
            metric[i] = func(stc_true.data[:, i : i + 1], stc_est.data[:, i : i + 1])
    else:
        metric = func(stc_true.data, stc_est.data)
    return metric


def _thresholding(stc_true, stc_est, threshold):
    """Zero out sub-threshold samples of both stcs, in place.

    ``threshold`` may be absolute (numeric) or relative (a percentage string,
    interpreted as a fraction of each stc's own max absolute value).
    ``stc_true`` may be None, in which case only ``stc_est`` is thresholded.
    """
    relative = isinstance(threshold, str)
    threshold = _check_threshold(threshold)
    if relative:
        if stc_true is not None:
            stc_true._data[
                np.abs(stc_true._data) <= threshold * np.max(np.abs(stc_true._data))
            ] = 0.0
        stc_est._data[
            np.abs(stc_est._data) <= threshold * np.max(np.abs(stc_est._data))
        ] = 0.0
    else:
        if stc_true is not None:
            stc_true._data[np.abs(stc_true._data) <= threshold] = 0.0
        stc_est._data[np.abs(stc_est._data) <= threshold] = 0.0
    return stc_true, stc_est


def _cosine(x, y):
    """Compute cosine similarity between the flattened arrays x and y."""
    p = x.ravel()
    q = y.ravel()
    p_norm = np.linalg.norm(p)
    q_norm = np.linalg.norm(q)
    if p_norm * q_norm:
        return (p.T @ q) / (p_norm * q_norm)
    elif p_norm == q_norm:
        # Both inputs are all-zero: treat them as identical.
        return 1
    else:
        # Exactly one input is all-zero: no similarity.
        return 0


@fill_doc
def cosine_score(stc_true, stc_est, per_sample=True):
    """Compute cosine similarity between 2 source estimates.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    .. versionadded:: 1.2
    """
    stc_true, stc_est = _uniform_stc(stc_true, stc_est)
    metric = _apply(_cosine, stc_true, stc_est, per_sample=per_sample)
    return metric


def _check_threshold(threshold):
    """Accept a float or a string that ends with %."""
    _validate_type(threshold, ("numeric", str), "threshold")
    if isinstance(threshold, str):
        if not threshold.endswith("%"):
            raise ValueError(
                f'Threshold if a string must end with "%". Got {threshold}.'
            )
        threshold = float(threshold[:-1]) / 100.0
    threshold = float(threshold)
    if not 0 <= threshold <= 1:
        raise ValueError(
            "Threshold proportion must be between 0 and 1 (inclusive), but "
            f"got {threshold}"
        )
    return threshold


def _abs_col_sum(x):
    """Sum the absolute values of each row (amplitude per vertex)."""
    return np.abs(x).sum(axis=1)


def _dle(p, q, src, stc):
    """Aux function to compute dipole localization error."""
    p = _abs_col_sum(p)
    q = _abs_col_sum(q)
    idx1 = np.nonzero(p)[0]
    idx2 = np.nonzero(q)[0]
    # Gather the 3D positions of every vertex used by the (uniformed) stc.
    points = []
    for i in range(len(src)):
        points.append(src[i]["rr"][stc.vertices[i]])
    points = np.concatenate(points, axis=0)
    if len(idx1) and len(idx2):
        # Symmetrized mean of nearest-neighbor distances between the two
        # active-source point sets.
        D = cdist(points[idx1], points[idx2])
        D_min_1 = np.min(D, axis=0)
        D_min_2 = np.min(D, axis=1)
        return (np.mean(D_min_1) + np.mean(D_min_2)) / 2.0
    else:
        # One of the estimates has no active sources: worst possible score.
        return np.inf


@fill_doc
def region_localization_error(stc_true, stc_est, src, threshold="90%", per_sample=True):
    r"""Compute region localization error (RLE) between 2 source estimates.

    .. math::

        RLE = \frac{1}{2Q}\sum_{k \in I} \min_{l \in \hat{I}}{||r_k - r_l||} + \frac{1}{2\hat{Q}}\sum_{l \in \hat{I}} \min_{k \in I}{||r_k - r_l||}

    where :math:`I` and :math:`\hat{I}` denote respectively the original and
    estimated indexes of active sources, :math:`Q` and :math:`\hat{Q}` are
    the numbers of original and estimated active sources.
    :math:`r_k` denotes the position of the k-th source dipole in space
    and :math:`||\cdot||` is an Euclidean norm in :math:`\mathbb{R}^3`.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    src : instance of SourceSpaces
        The source space on which the source estimates are defined.
    threshold : float | str
        The threshold to apply to source estimates before computing
        the dipole localization error. If a string the threshold is
        a percentage and it should end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    Papers :footcite:`MaksymenkoEtAl2017` and :footcite:`BeckerEtAl2017` use
    term Dipole Localization Error (DLE) for the same formula. Paper
    :footcite:`YaoEtAl2005` uses term Error Distance (ED) for the same formula.
    To unify the terminology and to avoid confusion with other cases of using
    term DLE but for different metric :footcite:`MolinsEtAl2008`, we use term
    Region Localization Error (RLE).

    .. versionadded:: 1.2

    References
    ----------
    .. footbibliography::
    """  # noqa: E501
    stc_true, stc_est = _uniform_stc(stc_true, stc_est)
    stc_true, stc_est = _thresholding(stc_true, stc_est, threshold)
    func = partial(_dle, src=src, stc=stc_true)
    metric = _apply(func, stc_true, stc_est, per_sample=per_sample)
    return metric


def _roc_auc_score(p, q):
    # Lazy import: sklearn is an optional dependency.
    from sklearn.metrics import roc_auc_score

    return roc_auc_score(np.abs(p) > 0, np.abs(q))


@fill_doc
def roc_auc_score(stc_true, stc_est, per_sample=True):
    """Compute ROC AUC between 2 source estimates.

    ROC stands for receiver operating curve and AUC is Area under the curve.
    When computing this metric the stc_true must be thresholded
    as any non-zero value will be considered as a positive.

    The ROC-AUC metric is computed between amplitudes of the source
    estimates, i.e. after taking the absolute values.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    .. versionadded:: 1.2
    """
    stc_true, stc_est = _uniform_stc(stc_true, stc_est)
    metric = _apply(_roc_auc_score, stc_true, stc_est, per_sample=per_sample)
    return metric


def _f1_score(p, q):
    # Lazy import: sklearn is an optional dependency.
    from sklearn.metrics import f1_score

    return f1_score(_abs_col_sum(p) > 0, _abs_col_sum(q) > 0)


@fill_doc
def f1_score(stc_true, stc_est, threshold="90%", per_sample=True):
    """Compute the F1 score, also known as balanced F-score or F-measure.

    The F1 score can be interpreted as a weighted average of the precision
    and recall, where an F1 score reaches its best value at 1 and worst
    score at 0. The relative contribution of precision and recall to the F1
    score are equal. The formula for the F1 score is::

        F1 = 2 * (precision * recall) / (precision + recall)

    Threshold is used first for data binarization.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    threshold : float | str
        The threshold to apply to source estimates before computing
        the f1 score. If a string the threshold is a percentage and
        it should end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    .. versionadded:: 1.2
    """
    stc_true, stc_est = _uniform_stc(stc_true, stc_est)
    stc_true, stc_est = _thresholding(stc_true, stc_est, threshold)
    metric = _apply(_f1_score, stc_true, stc_est, per_sample=per_sample)
    return metric


def _precision_score(p, q):
    # Lazy import: sklearn is an optional dependency.
    from sklearn.metrics import precision_score

    return precision_score(_abs_col_sum(p) > 0, _abs_col_sum(q) > 0)


@fill_doc
def precision_score(stc_true, stc_est, threshold="90%", per_sample=True):
    """Compute the precision.

    The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number
    of true positives and ``fp`` the number of false positives. The precision
    is intuitively the ability of the classifier not to label as positive a
    sample that is negative.

    The best value is 1 and the worst value is 0.

    Threshold is used first for data binarization.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    threshold : float | str
        The threshold to apply to source estimates before computing
        the precision. If a string the threshold is a percentage and
        it should end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    .. versionadded:: 1.2
    """
    stc_true, stc_est = _uniform_stc(stc_true, stc_est)
    stc_true, stc_est = _thresholding(stc_true, stc_est, threshold)
    metric = _apply(_precision_score, stc_true, stc_est, per_sample=per_sample)
    return metric


def _recall_score(p, q):
    # Lazy import: sklearn is an optional dependency.
    from sklearn.metrics import recall_score

    return recall_score(_abs_col_sum(p) > 0, _abs_col_sum(q) > 0)


@fill_doc
def recall_score(stc_true, stc_est, threshold="90%", per_sample=True):
    """Compute the recall.

    The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
    true positives and ``fn`` the number of false negatives. The recall is
    intuitively the ability of the classifier to find all the positive
    samples.

    The best value is 1 and the worst value is 0.

    Threshold is used first for data binarization.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    threshold : float | str
        The threshold to apply to source estimates before computing
        the recall. If a string the threshold is a percentage and
        it should end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    .. versionadded:: 1.2
    """
    stc_true, stc_est = _uniform_stc(stc_true, stc_est)
    stc_true, stc_est = _thresholding(stc_true, stc_est, threshold)
    metric = _apply(_recall_score, stc_true, stc_est, per_sample=per_sample)
    return metric


def _prepare_ppe_sd(stc_true, stc_est, src, threshold="50%"):
    """Validate the single-dipole true source and threshold the estimate.

    Returns the thresholded estimate plus the true dipole position
    ``r_true`` (1, 3) and the positions ``r_est`` (n_active, 3) of all
    estimated vertices. Raises ValueError unless stc_true contains exactly
    one dipole.
    """
    stc_true = stc_true.copy()
    stc_est = stc_est.copy()
    n_dipoles = 0
    for i, v in enumerate(stc_true.vertices):
        if len(v):
            n_dipoles += len(v)
            r_true = src[i]["rr"][v]
    if n_dipoles != 1:
        raise ValueError(f"True source must contain only one dipole, got {n_dipoles}.")
    _, stc_est = _thresholding(None, stc_est, threshold)
    r_est = np.empty([0, 3])
    for i, v in enumerate(stc_est.vertices):
        if len(v):
            r_est = np.vstack([r_est, src[i]["rr"][v]])
    return stc_est, r_true, r_est


def _peak_position_error(p, q, r_est, r_true):
    """Distance from the amplitude-weighted centroid of q to r_true."""
    q = _abs_col_sum(q)
    if np.sum(q):
        # Normalize amplitudes to weights, then take the weighted mean
        # position of the estimated sources.
        q /= np.sum(q)
        r_est_mean = np.dot(q, r_est)
        return np.linalg.norm(r_est_mean - r_true)
    else:
        # No active estimated sources: worst possible score.
        return np.inf


@fill_doc
def peak_position_error(stc_true, stc_est, src, threshold="50%", per_sample=True):
    r"""Compute the peak position error.

    The peak position error measures the distance between the center-of-mass
    of the estimated and the true source.

    .. math::

        PPE = \| \dfrac{\sum_i|s_i|r_{i}}{\sum_i|s_i|} - r_{true}\|,

    where :math:`r_{true}` is a true dipole position, :math:`r_i` and
    :math:`|s_i|` denote respectively the position and amplitude of i-th
    dipole in source estimate.

    Threshold is used on estimated source for focusing the metric to strong
    amplitudes and omitting the low-amplitude values.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    src : instance of SourceSpaces
        The source space on which the source estimates are defined.
    threshold : float | str
        The threshold to apply to source estimates before computing
        the recall. If a string the threshold is a percentage and
        it should end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    These metrics are documented in :footcite:`StenroosHauk2013` and
    :footcite:`LinEtAl2006a`.

    .. versionadded:: 1.2

    References
    ----------
    .. footbibliography::
    """
    stc_est, r_true, r_est = _prepare_ppe_sd(stc_true, stc_est, src, threshold)
    func = partial(_peak_position_error, r_est=r_est, r_true=r_true)
    metric = _apply(func, stc_true, stc_est, per_sample=per_sample)
    return metric


def _spatial_deviation(p, q, r_est, r_true):
    """Amplitude-weighted RMS spread of q's sources around r_true."""
    q = _abs_col_sum(q)
    if np.sum(q):
        q /= np.sum(q)
        # Squared distance of every estimated source to the true position,
        # averaged with the normalized amplitudes as weights.
        r_true_tile = np.tile(r_true, (r_est.shape[0], 1))
        r_diff = r_est - r_true_tile
        r_diff_norm = np.sum(r_diff**2, axis=1)
        return np.sqrt(np.dot(q, r_diff_norm))
    else:
        # No active estimated sources: worst possible score.
        return np.inf


@fill_doc
def spatial_deviation_error(stc_true, stc_est, src, threshold="50%", per_sample=True):
    r"""Compute the spatial deviation.

    The spatial deviation characterizes the spread of the estimate source
    around the true source.

    .. math::

        SD = \dfrac{\sum_i|s_i|\|r_{i} - r_{true}\|^2}{\sum_i|s_i|}.

    where :math:`r_{true}` is a true dipole position, :math:`r_i` and
    :math:`|s_i|` denote respectively the position and amplitude of i-th
    dipole in source estimate.

    Threshold is used on estimated source for focusing the metric to strong
    amplitudes and omitting the low-amplitude values.

    Parameters
    ----------
    %(stc_true_metric)s
    %(stc_est_metric)s
    src : instance of SourceSpaces
        The source space on which the source estimates are defined.
    threshold : float | str
        The threshold to apply to source estimates before computing
        the recall. If a string the threshold is a percentage and
        it should end with the percent character.
    %(per_sample_metric)s

    Returns
    -------
    %(stc_metric)s

    Notes
    -----
    These metrics are documented in :footcite:`StenroosHauk2013` and
    :footcite:`LinEtAl2006a`.

    .. versionadded:: 1.2

    References
    ----------
    .. footbibliography::
    """
    stc_est, r_true, r_est = _prepare_ppe_sd(stc_true, stc_est, src, threshold)
    func = partial(_spatial_deviation, r_est=r_est, r_true=r_true)
    metric = _apply(func, stc_true, stc_est, per_sample=per_sample)
    return metric