# Authors: The MNE-Python contributors. # License: BSD-3-Clause # Copyright the MNE-Python contributors. from functools import partial import numpy as np from ..defaults import _handle_default from ..fixes import _safe_svd from ..utils import eigh, logger, sqrtm_sym, warn # For the reference implementation of eLORETA (force_equal=False), # 0 < loose <= 1 all produce solutions that are (more or less) # the same as free orientation (loose=1) and quite different from # loose=0 (fixed). If we do force_equal=True, we get a visibly smooth # transition from 0->1. This is probably because this mode behaves more like # sLORETA and dSPM in that it weights each orientation for a given source # uniformly (which is not the case for the reference eLORETA implementation). # # If we *reapply the orientation prior* after each eLORETA iteration, # we can preserve the smooth transition without requiring force_equal=True, # which is probably more representative of what eLORETA should do. But this # does not produce results that pass the eye test. def _compute_eloreta(inv, lambda2, options): """Compute the eLORETA solution.""" from .inverse import _compute_reginv, compute_rank_inverse options = _handle_default("eloreta_options", options) eps, max_iter = options["eps"], options["max_iter"] force_equal = bool(options["force_equal"]) # None means False # Reassemble the gain matrix (should be fast enough) if inv["eigen_leads_weighted"]: # We can probably relax this if we ever need to raise RuntimeError("eLORETA cannot be computed with weighted eigen leads") G = np.dot( inv["eigen_fields"]["data"].T * inv["sing"], inv["eigen_leads"]["data"].T ) del inv["eigen_leads"]["data"] del inv["eigen_fields"]["data"] del inv["sing"] G = G.astype(np.float64) n_nzero = compute_rank_inverse(inv) G /= np.sqrt(inv["source_cov"]["data"]) # restore orientation prior source_std = np.ones(G.shape[1]) if inv["orient_prior"] is not None: source_std *= np.sqrt(inv["orient_prior"]["data"]) G *= source_std # We do not multiply by the depth prior, as eLORETA should compensate for # depth bias. n_src = inv["nsource"] n_chan, n_orient = G.shape n_orient //= n_src assert n_orient in (1, 3) logger.info(" Computing optimized source covariance (eLORETA)...") if n_orient == 3: logger.info( f" Using {'uniform' if force_equal else 'independent'} " "orientation weights" ) # src, sens, 3 G_3 = _get_G_3(G, n_orient) if n_orient != 1 and not force_equal: # Outer product R_prior = source_std.reshape(n_src, 1, 3) * source_std.reshape(n_src, 3, 1) else: R_prior = source_std**2 # The following was adapted under BSD license by permission of Guido Nolte if force_equal or n_orient == 1: R_shape = (n_src * n_orient,) R = np.ones(R_shape) else: R_shape = (n_src, n_orient, n_orient) R = np.empty(R_shape) R[:] = np.eye(n_orient)[np.newaxis] R *= R_prior _this_normalize_R = partial( _normalize_R, n_nzero=n_nzero, force_equal=force_equal, n_src=n_src, n_orient=n_orient, ) G_R_Gt = _this_normalize_R(G, R, G_3) extra = " (this make take a while)" if n_orient == 3 else "" logger.info(f" Fitting up to {max_iter} iterations{extra}...") for kk in range(max_iter): # 1. Compute inverse of the weights (stabilized) and C s, u = eigh(G_R_Gt) s = abs(s) sidx = np.argsort(s)[::-1][:n_nzero] s, u = s[sidx], u[:, sidx] with np.errstate(invalid="ignore"): s = np.where(s > 0, 1 / (s + lambda2), 0) N = np.dot(u * s, u.T) del s # Update the weights R_last = R.copy() if n_orient == 1: R[:] = 1.0 / np.sqrt((np.dot(N, G) * G).sum(0)) else: M = np.matmul(np.matmul(G_3, N[np.newaxis]), G_3.swapaxes(-2, -1)) if force_equal: _, s = sqrtm_sym(M, inv=True) R[:] = np.repeat(1.0 / np.mean(s, axis=-1), 3) else: R[:], _ = sqrtm_sym(M, inv=True) R *= R_prior # reapply our prior, eLORETA undoes it G_R_Gt = _this_normalize_R(G, R, G_3) # Check for weight convergence delta = np.linalg.norm(R.ravel() - R_last.ravel()) / np.linalg.norm( R_last.ravel() ) logger.debug( f" Iteration {kk + 1} / {max_iter} ...{extra} ({delta:0.1e})" ) if delta < eps: logger.info( f" Converged on iteration {kk} ({delta:.2g} < {eps:.2g})" ) break else: warn(f"eLORETA weight fitting did not converge (>= {eps})") del G_R_Gt logger.info(" Updating inverse with weighted eigen leads") G /= source_std # undo our biasing G_3 = _get_G_3(G, n_orient) _this_normalize_R(G, R, G_3) del G_3 if n_orient == 1 or force_equal: R_sqrt = np.sqrt(R) else: R_sqrt = sqrtm_sym(R)[0] assert R_sqrt.shape == R_shape A = _R_sqrt_mult(G, R_sqrt) del R, G # the rest will be done in terms of R_sqrt and A eigen_fields, sing, eigen_leads = _safe_svd(A, full_matrices=False) del A inv["sing"] = sing inv["reginv"] = _compute_reginv(inv, lambda2) inv["eigen_leads_weighted"] = True inv["eigen_leads"]["data"] = _R_sqrt_mult(eigen_leads, R_sqrt).T inv["eigen_fields"]["data"] = eigen_fields.T # XXX in theory we should set inv['source_cov'] properly. # For fixed ori (or free ori with force_equal=True), we can as these # are diagonal matrices. But for free ori without force_equal, it's a # block diagonal 3x3 and we have no efficient way of storing this (and # storing a covariance matrix with (20484 * 3) ** 2 elements is not going # to work. So let's just set to nan for now. # It's not used downstream anyway now that we set # eigen_leads_weighted = True. inv["source_cov"]["data"].fill(np.nan) logger.info("[done]") def _normalize_R(G, R, G_3, n_nzero, force_equal, n_src, n_orient): """Normalize R so that lambda2 is consistent.""" if n_orient == 1 or force_equal: R_Gt = R[:, np.newaxis] * G.T else: R_Gt = np.matmul(R, G_3).reshape(n_src * 3, -1) G_R_Gt = G @ R_Gt norm = np.trace(G_R_Gt) / n_nzero G_R_Gt /= norm R /= norm return G_R_Gt def _get_G_3(G, n_orient): if n_orient == 1: return None else: return G.reshape(G.shape[0], -1, n_orient).transpose(1, 2, 0) def _R_sqrt_mult(other, R_sqrt): """Do other @ R ** 0.5.""" if R_sqrt.ndim == 1: assert other.shape[1] == R_sqrt.size out = R_sqrt * other else: assert R_sqrt.shape[1:3] == (3, 3) assert other.shape[1] == np.prod(R_sqrt.shape[:2]) assert other.ndim == 2 n_src = R_sqrt.shape[0] n_chan = other.shape[0] out = ( np.matmul(R_sqrt, other.reshape(n_chan, n_src, 3).transpose(1, 2, 0)) .reshape(n_src * 3, n_chan) .T ) return out