200 lines
7.1 KiB
Python
200 lines
7.1 KiB
Python
# Authors: The MNE-Python contributors.
|
|
# License: BSD-3-Clause
|
|
# Copyright the MNE-Python contributors.
|
|
|
|
from functools import partial
|
|
|
|
import numpy as np
|
|
|
|
from ..defaults import _handle_default
|
|
from ..fixes import _safe_svd
|
|
from ..utils import eigh, logger, sqrtm_sym, warn
|
|
|
|
# For the reference implementation of eLORETA (force_equal=False),
|
|
# 0 < loose <= 1 all produce solutions that are (more or less)
|
|
# the same as free orientation (loose=1) and quite different from
|
|
# loose=0 (fixed). If we do force_equal=True, we get a visibly smooth
|
|
# transition from 0->1. This is probably because this mode behaves more like
|
|
# sLORETA and dSPM in that it weights each orientation for a given source
|
|
# uniformly (which is not the case for the reference eLORETA implementation).
|
|
#
|
|
# If we *reapply the orientation prior* after each eLORETA iteration,
|
|
# we can preserve the smooth transition without requiring force_equal=True,
|
|
# which is probably more representative of what eLORETA should do. But this
|
|
# does not produce results that pass the eye test.
|
|
|
|
|
|
def _compute_eloreta(inv, lambda2, options):
    """Compute the eLORETA solution, updating ``inv`` in place.

    The gain matrix is rebuilt from the inverse operator's SVD, the eLORETA
    source weights ``R`` are fitted iteratively, and the decomposition stored
    in ``inv`` is replaced by the SVD of the weighted gain matrix.

    Parameters
    ----------
    inv : dict
        Inverse operator. Its ``eigen_leads['data']``, ``eigen_fields['data']``
        and ``sing`` entries are consumed (deleted) and then replaced by the
        eLORETA decomposition. ``reginv`` is recomputed,
        ``eigen_leads_weighted`` is set to True, and ``source_cov['data']`` is
        filled with NaN.
    lambda2 : float
        Regularization parameter. It is added to the retained eigenvalues of
        the normalized ``G R G^T`` during fitting and passed to
        ``_compute_reginv``.
    options : dict
        eLORETA options. They are completed via
        ``_handle_default("eloreta_options", options)``, after which the
        ``eps``, ``max_iter`` and ``force_equal`` entries are read.

    Raises
    ------
    RuntimeError
        If ``inv`` already has weighted eigen leads.

    Notes
    -----
    A warning is emitted (via ``warn``) if the relative weight change does not
    drop below ``eps`` within ``max_iter`` iterations.
    """
    from .inverse import _compute_reginv, compute_rank_inverse

    options = _handle_default("eloreta_options", options)
    eps, max_iter = options["eps"], options["max_iter"]
    force_equal = bool(options["force_equal"])  # None means False

    # Reassemble the gain matrix (should be fast enough)
    if inv["eigen_leads_weighted"]:
        # We can probably relax this if we ever need to
        raise RuntimeError("eLORETA cannot be computed with weighted eigen leads")
    G = np.dot(
        inv["eigen_fields"]["data"].T * inv["sing"], inv["eigen_leads"]["data"].T
    )
    # Free the old decomposition; it is replaced at the end of this function.
    del inv["eigen_leads"]["data"]
    del inv["eigen_fields"]["data"]
    del inv["sing"]
    G = G.astype(np.float64)
    n_nzero = compute_rank_inverse(inv)
    G /= np.sqrt(inv["source_cov"]["data"])
    # restore orientation prior
    source_std = np.ones(G.shape[1])
    if inv["orient_prior"] is not None:
        source_std *= np.sqrt(inv["orient_prior"]["data"])
    G *= source_std
    # We do not multiply by the depth prior, as eLORETA should compensate for
    # depth bias.
    n_src = inv["nsource"]
    n_orient = G.shape[1] // n_src
    assert n_orient in (1, 3)
    logger.info(" Computing optimized source covariance (eLORETA)...")
    if n_orient == 3:
        logger.info(
            f" Using {'uniform' if force_equal else 'independent'} "
            "orientation weights"
        )
    # (n_src, n_orient, n_chan), or None for fixed orientation
    G_3 = _get_G_3(G, n_orient)
    if n_orient != 1 and not force_equal:
        # Outer product: per-source (n_orient, n_orient) prior blocks
        R_prior = source_std.reshape(n_src, 1, n_orient) * source_std.reshape(
            n_src, n_orient, 1
        )
    else:
        R_prior = source_std**2

    # The following was adapted under BSD license by permission of Guido Nolte
    if force_equal or n_orient == 1:
        # One weight per source column (diagonal R)
        R_shape = (n_src * n_orient,)
        R = np.ones(R_shape)
    else:
        # One (n_orient, n_orient) block per source, initialized to identity
        R_shape = (n_src, n_orient, n_orient)
        R = np.empty(R_shape)
        R[:] = np.eye(n_orient)[np.newaxis]
    R *= R_prior
    _this_normalize_R = partial(
        _normalize_R,
        n_nzero=n_nzero,
        force_equal=force_equal,
        n_src=n_src,
        n_orient=n_orient,
    )
    G_R_Gt = _this_normalize_R(G, R, G_3)
    extra = " (this may take a while)" if n_orient == 3 else ""
    logger.info(f" Fitting up to {max_iter} iterations{extra}...")
    for kk in range(max_iter):
        # 1. Compute inverse of the weights (stabilized) and C
        s, u = eigh(G_R_Gt)
        s = abs(s)
        # Keep only the n_nzero largest-magnitude components
        sidx = np.argsort(s)[::-1][:n_nzero]
        s, u = s[sidx], u[:, sidx]
        # With lambda2 == 0 a zero eigenvalue would divide by zero; np.where
        # discards that entry, so silence both invalid and divide warnings.
        with np.errstate(invalid="ignore", divide="ignore"):
            s = np.where(s > 0, 1 / (s + lambda2), 0)
        N = np.dot(u * s, u.T)
        del s

        # Update the weights
        R_last = R.copy()
        if n_orient == 1:
            R[:] = 1.0 / np.sqrt((np.dot(N, G) * G).sum(0))
        else:
            M = np.matmul(np.matmul(G_3, N[np.newaxis]), G_3.swapaxes(-2, -1))
            if force_equal:
                _, s = sqrtm_sym(M, inv=True)
                R[:] = np.repeat(1.0 / np.mean(s, axis=-1), n_orient)
            else:
                R[:], _ = sqrtm_sym(M, inv=True)
        R *= R_prior  # reapply our prior, eLORETA undoes it
        G_R_Gt = _this_normalize_R(G, R, G_3)

        # Check for weight convergence (relative change of R)
        delta = np.linalg.norm(R.ravel() - R_last.ravel()) / np.linalg.norm(
            R_last.ravel()
        )
        logger.debug(
            f" Iteration {kk + 1} / {max_iter} ...{extra} ({delta:0.1e})"
        )
        if delta < eps:
            # Report 1-based, consistent with the per-iteration debug message
            logger.info(
                f" Converged on iteration {kk + 1} ({delta:.2g} < {eps:.2g})"
            )
            break
    else:
        warn(f"eLORETA weight fitting did not converge (>= {eps})")
    del G_R_Gt
    logger.info(" Updating inverse with weighted eigen leads")
    G /= source_std  # undo our biasing
    G_3 = _get_G_3(G, n_orient)
    # Called for its in-place rescaling of R; the returned G R G^T is unused.
    _this_normalize_R(G, R, G_3)
    del G_3
    if n_orient == 1 or force_equal:
        R_sqrt = np.sqrt(R)
    else:
        R_sqrt = sqrtm_sym(R)[0]
    assert R_sqrt.shape == R_shape
    A = _R_sqrt_mult(G, R_sqrt)
    del R, G  # the rest will be done in terms of R_sqrt and A
    eigen_fields, sing, eigen_leads = _safe_svd(A, full_matrices=False)
    del A
    inv["sing"] = sing
    inv["reginv"] = _compute_reginv(inv, lambda2)
    inv["eigen_leads_weighted"] = True
    inv["eigen_leads"]["data"] = _R_sqrt_mult(eigen_leads, R_sqrt).T
    inv["eigen_fields"]["data"] = eigen_fields.T
    # XXX in theory we should set inv['source_cov'] properly.
    # For fixed ori (or free ori with force_equal=True), we can as these
    # are diagonal matrices. But for free ori without force_equal, it's a
    # block diagonal 3x3 and we have no efficient way of storing this (and
    # storing a covariance matrix with (20484 * 3) ** 2 elements is not going
    # to work). So let's just set to nan for now.
    # It's not used downstream anyway now that we set
    # eigen_leads_weighted = True.
    inv["source_cov"]["data"].fill(np.nan)
    logger.info("[done]")
|
|
|
|
|
|
def _normalize_R(G, R, G_3, n_nzero, force_equal, n_src, n_orient):
|
|
"""Normalize R so that lambda2 is consistent."""
|
|
if n_orient == 1 or force_equal:
|
|
R_Gt = R[:, np.newaxis] * G.T
|
|
else:
|
|
R_Gt = np.matmul(R, G_3).reshape(n_src * 3, -1)
|
|
G_R_Gt = G @ R_Gt
|
|
norm = np.trace(G_R_Gt) / n_nzero
|
|
G_R_Gt /= norm
|
|
R /= norm
|
|
return G_R_Gt
|
|
|
|
|
|
def _get_G_3(G, n_orient):
|
|
if n_orient == 1:
|
|
return None
|
|
else:
|
|
return G.reshape(G.shape[0], -1, n_orient).transpose(1, 2, 0)
|
|
|
|
|
|
def _R_sqrt_mult(other, R_sqrt):
|
|
"""Do other @ R ** 0.5."""
|
|
if R_sqrt.ndim == 1:
|
|
assert other.shape[1] == R_sqrt.size
|
|
out = R_sqrt * other
|
|
else:
|
|
assert R_sqrt.shape[1:3] == (3, 3)
|
|
assert other.shape[1] == np.prod(R_sqrt.shape[:2])
|
|
assert other.ndim == 2
|
|
n_src = R_sqrt.shape[0]
|
|
n_chan = other.shape[0]
|
|
out = (
|
|
np.matmul(R_sqrt, other.reshape(n_chan, n_src, 3).transpose(1, 2, 0))
|
|
.reshape(n_src * 3, n_chan)
|
|
.T
|
|
)
|
|
return out
|