# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np

from ..fixes import _safe_svd
from ..forward import is_fixed_orient
from ..minimum_norm.inverse import _check_reference, _log_exp_var
from ..utils import logger, verbose, warn
from .mxne_inverse import (
    _check_ori,
    _compute_residual,
    _make_dipoles_sparse,
    _make_sparse_stc,
    _prepare_gain,
    _reapply_source_weighting,
)


@verbose
def _gamma_map_opt(
    M,
    G,
    alpha,
    maxit=10000,
    tol=1e-6,
    update_mode=1,
    group_size=1,
    gammas=None,
    verbose=None,
):
    """Hierarchical Bayes (Gamma-MAP).

    Parameters
    ----------
    M : array, shape=(n_sensors, n_times)
        Observation.
    G : array, shape=(n_sensors, n_sources)
        Forward operator.
    alpha : float
        Regularization parameter (noise variance).
    maxit : int
        Maximum number of iterations. Must be at least 1.
    tol : float
        Tolerance parameter for convergence.
    update_mode : int
        Update mode, 1: MacKay update (default), 2: Modified MacKay update.
    group_size : int
        Number of consecutive sources which use the same gamma.
    gammas : array, shape=(n_sources,)
        Initial values for posterior variances (gammas). If None, a variance
        of 1.0 is used. The array passed in is copied, never modified.
    %(verbose)s

    Returns
    -------
    X : array, shape=(n_active, n_times)
        Estimated source time courses.
    active_set : array, shape=(n_active,)
        Indices of active sources.

    Raises
    ------
    ValueError
        If ``update_mode`` is not 1 or 2, if ``maxit`` is smaller than 1,
        if ``gammas`` does not hold one value per source, or if the number
        of sources is not a multiple of ``group_size``.
    """
    # Validate arguments up front, before any (costly) normalization work.
    if update_mode not in (1, 2):
        raise ValueError("Invalid value for update_mode")
    if maxit < 1:
        raise ValueError(f"maxit must be at least 1, got {maxit}")

    G = G.copy()
    M = M.copy()

    n_sources = G.shape[1]
    n_sensors = M.shape[0]

    if n_sources % group_size != 0:
        raise ValueError(
            "Number of sources has to be evenly dividable by the group size"
        )

    if gammas is None:
        gammas = np.ones(n_sources, dtype=np.float64)
    else:
        # Private float copy: the NaN-zeroing inside the loop must never
        # write into the caller's array.
        gammas = np.array(gammas, dtype=np.float64)
        if gammas.shape != (n_sources,):
            raise ValueError(
                f"gammas must have shape ({n_sources},), got {gammas.shape}"
            )

    eps = np.finfo(float).eps

    # apply normalization so the numerical values are sane
    M_normalize_constant = np.linalg.norm(np.dot(M, M.T), ord="fro")
    M /= np.sqrt(M_normalize_constant)
    alpha /= M_normalize_constant
    G_normalize_constant = np.linalg.norm(G, ord=np.inf)
    G /= G_normalize_constant

    n_active = n_sources
    active_set = np.arange(n_sources)

    gammas_full_old = gammas.copy()

    if update_mode == 2:
        denom_fun = np.sqrt
    else:
        # do nothing
        def denom_fun(x):
            return x

    last_size = -1
    for itno in range(maxit):
        gammas[np.isnan(gammas)] = 0.0

        gidx = np.abs(gammas) > eps
        active_set = active_set[gidx]
        gammas = gammas[gidx]

        # update only active gammas (once set to zero it stays at zero)
        if n_active > len(active_set):
            n_active = active_set.size
            G = G[:, gidx]

        # Model data covariance: G diag(gammas) G^T + alpha * I
        CM = np.dot(G * gammas[np.newaxis, :], G.T)
        CM.flat[:: n_sensors + 1] += alpha
        # Invert CM keeping symmetry
        U, S, _ = _safe_svd(CM, full_matrices=False)
        S = S[np.newaxis, :]
        del CM
        CMinv = np.dot(U / (S + eps), U.T)
        CMinvG = np.dot(CMinv, G)
        A = np.dot(CMinvG.T, M)  # mult. w. Diag(gamma) in gamma update

        if update_mode == 1:
            # MacKay fixed point update (10) in [1]
            numer = gammas**2 * np.mean((A * A.conj()).real, axis=1)
            denom = gammas * np.sum(G * CMinvG, axis=0)
        else:
            # modified MacKay fixed point update (11) in [1]
            numer = gammas * np.sqrt(np.mean((A * A.conj()).real, axis=1))
            denom = np.sum(G * CMinvG, axis=0)  # sqrt is applied below

        # Both branches floor the denominator at eps so a vanishing
        # denominator cannot produce inf / NaN gammas.
        if group_size == 1:
            gammas = numer / np.maximum(denom_fun(denom), eps)
        else:
            numer_comb = np.sum(numer.reshape(-1, group_size), axis=1)
            denom_comb = np.sum(denom.reshape(-1, group_size), axis=1)
            gammas_comb = numer_comb / np.maximum(denom_fun(denom_comb), eps)
            gammas = np.repeat(gammas_comb / group_size, group_size)

        # compute convergence criterion (relative L1 change of all gammas)
        gammas_full = np.zeros(n_sources, dtype=np.float64)
        gammas_full[active_set] = gammas

        old_norm = np.sum(np.abs(gammas_full_old))
        if old_norm == 0:
            # Every gamma was already zero, so nothing can change any more.
            err = 0.0
        else:
            err = np.sum(np.abs(gammas_full - gammas_full_old)) / old_norm
        gammas_full_old = gammas_full

        breaking = err < tol or n_active == 0
        if len(gammas) != last_size or breaking:
            logger.info(
                f"Iteration: {itno}\t active set size: {len(gammas)}\t convergence: "
                f"{err:.3e}"
            )
            last_size = len(gammas)

        if breaking:
            logger.info("\nConvergence reached !\n")
            break
    else:
        # Loop exhausted maxit iterations without meeting the criterion.
        warn("\nConvergence NOT reached !\n")

    # undo normalization and compute final posterior mean; note that
    # ``gammas`` here is the most recent update, computed after ``A``.
    n_const = np.sqrt(M_normalize_constant) / G_normalize_constant
    x_active = n_const * gammas[:, None] * A

    return x_active, active_set


@verbose
def gamma_map(
    evoked,
    forward,
    noise_cov,
    alpha,
    loose="auto",
    depth=0.8,
    xyz_same_gamma=True,
    maxit=10000,
    tol=1e-6,
    update_mode=1,
    gammas=None,
    pca=True,
    return_residual=False,
    return_as_dipoles=False,
    rank=None,
    pick_ori=None,
    verbose=None,
):
    """Hierarchical Bayes (Gamma-MAP) sparse source localization method.

    Models each source time course using a zero-mean Gaussian prior with an
    unknown variance (gamma) parameter. During estimation, most gammas are
    driven to zero, resulting in a sparse source estimate, as in
    :footcite:`WipfEtAl2007` and :footcite:`WipfNagarajan2009`.

    For fixed-orientation forward operators, a separate gamma is used for each
    source time course, while for free-orientation forward operators, the same
    gamma is used for the three source time courses at each source space point
    (separate gammas can be used in this case by using xyz_same_gamma=False).

    Parameters
    ----------
    evoked : instance of Evoked
        Evoked data to invert.
    forward : dict
        Forward operator.
    noise_cov : instance of Covariance
        Noise covariance to compute whitener.
    alpha : float
        Regularization parameter (noise variance).
    %(loose)s
    %(depth)s
    xyz_same_gamma : bool
        Use same gamma for xyz current components at each source space point.
        Recommended for free-orientation forward solutions.
    maxit : int
        Maximum number of iterations.
    tol : float
        Tolerance parameter for convergence.
    update_mode : int
        Update mode, 1: MacKay update (default), 2: Modified MacKay update.
    gammas : array, shape=(n_sources,)
        Initial values for posterior variances (gammas). If None, a
        variance of 1.0 is used.
    pca : bool
        If True the rank of the data is reduced to the true dimension.
    return_residual : bool
        If True, the residual is returned as an Evoked instance.
    return_as_dipoles : bool
        If True, the sources are returned as a list of Dipole instances.
    %(rank_none)s

        .. versionadded:: 0.18
    %(pick_ori)s
    %(verbose)s

    Returns
    -------
    stc : instance of SourceEstimate | list of Dipole
        Source time courses, or a list of Dipole instances if
        ``return_as_dipoles`` is True.
    residual : instance of Evoked
        The residual a.k.a. data not explained by the sources.
        Only returned if return_residual is True.

    Raises
    ------
    ValueError
        If no source remains active after the optimization, which happens
        when ``alpha`` is too large.

    References
    ----------
    .. footbibliography::
    """
    _check_reference(evoked)

    forward, gain, gain_info, whitener, source_weighting, _ = _prepare_gain(
        forward, evoked.info, noise_cov, pca, depth, loose, rank
    )
    _check_ori(pick_ori, forward)

    # Free orientation with a shared gamma groups the 3 xyz components.
    group_size = 1 if (is_fixed_orient(forward) or not xyz_same_gamma) else 3

    # get the data
    sel = [evoked.ch_names.index(name) for name in gain_info["ch_names"]]
    M = evoked.data[sel]

    # whiten the data
    logger.info("Whitening data matrix.")
    M = np.dot(whitener, M)

    # run the optimization
    X, active_set = _gamma_map_opt(
        M,
        gain,
        alpha,
        maxit=maxit,
        tol=tol,
        update_mode=update_mode,
        gammas=gammas,
        group_size=group_size,
        verbose=verbose,
    )

    if len(active_set) == 0:
        # ValueError subclasses Exception, so existing handlers still work.
        raise ValueError("No active dipoles found. alpha is too big.")

    M_estimate = gain[:, active_set] @ X

    # Reapply weights to have correct unit
    X = _reapply_source_weighting(X, source_weighting, active_set)

    if return_residual:
        residual = _compute_residual(forward, evoked, X, active_set, gain_info)

    if group_size == 1 and not is_fixed_orient(forward):
        # make sure each source has 3 components
        idx, offset = divmod(active_set, 3)
        active_src = np.unique(idx)
        if len(X) < 3 * len(active_src):
            X_xyz = np.zeros((len(active_src), 3, X.shape[1]), dtype=X.dtype)
            idx = np.searchsorted(active_src, idx)
            X_xyz[idx, offset, :] = X
            X = X_xyz.reshape(len(active_src) * 3, X.shape[1])
        active_set = (active_src[:, np.newaxis] * 3 + np.arange(3)).ravel()
    source_weighting[source_weighting == 0] = 1  # zeros
    gain_active = gain[:, active_set] / source_weighting[active_set]
    del source_weighting

    tmin = evoked.times[0]
    tstep = 1.0 / evoked.info["sfreq"]

    if return_as_dipoles:
        out = _make_dipoles_sparse(
            X, active_set, forward, tmin, tstep, M, gain_active, active_is_idx=True
        )
    else:
        out = _make_sparse_stc(
            X,
            active_set,
            forward,
            tmin,
            tstep,
            active_is_idx=True,
            pick_ori=pick_ori,
            verbose=verbose,
        )

    _log_exp_var(M, M_estimate, prefix="")
    logger.info("[done]")

    if return_residual:
        out = out, residual

    return out