1679 lines
52 KiB
Python
1679 lines
52 KiB
Python
# Authors: The MNE-Python contributors.
|
|
# License: BSD-3-Clause
|
|
# Copyright the MNE-Python contributors.
|
|
|
|
import functools
|
|
from math import sqrt
|
|
|
|
import numpy as np
|
|
|
|
from ..time_frequency._stft import istft, stft, stft_norm1, stft_norm2
|
|
from ..utils import (
|
|
_check_option,
|
|
_get_blas_funcs,
|
|
_validate_type,
|
|
logger,
|
|
sum_squared,
|
|
verbose,
|
|
warn,
|
|
)
|
|
from .mxne_debiasing import compute_bias
|
|
|
|
|
|
@functools.lru_cache(None)
def _get_dgemm():
    """Return the BLAS dgemm routine for float64, cached after first lookup."""
    return _get_blas_funcs(np.float64, "gemm")
|
|
|
|
|
|
def groups_norm2(A, n_orient):
    """Compute the squared L2 norm of each group of ``n_orient`` rows.

    Note that ``A`` is squared in place as a side effect.
    """
    A *= A  # square in place, exactly like np.power(A, 2, A)
    n_groups = A.shape[0] // n_orient
    return A.reshape(n_groups, -1).sum(axis=1)
|
|
|
|
|
|
def norm_l2inf(A, n_orient, copy=True):
    """Mixed L2-infinity norm: the largest group L2 norm.

    Parameters
    ----------
    A : array, shape (n_sources, n_times)
        Input coefficients.
    n_orient : int
        Number of rows per group.
    copy : bool
        If True (default), work on a copy: ``groups_norm2`` squares its
        input in place, so pass False only when clobbering ``A`` is OK.

    Returns
    -------
    norm : float
        ``max_g ||A_g||_2`` over the row groups ``g``.
    """
    if not A.size:
        return 0.0
    work = A.copy() if copy else A
    return sqrt(groups_norm2(work, n_orient).max())
|
|
|
|
|
|
def norm_l21(A, n_orient, copy=True):
    """Mixed L21 norm: the sum of the group L2 norms.

    Parameters
    ----------
    A : array, shape (n_sources, n_times)
        Input coefficients.
    n_orient : int
        Number of rows per group.
    copy : bool
        If True (default), work on a copy: ``groups_norm2`` squares its
        input in place, so pass False only when clobbering ``A`` is OK.

    Returns
    -------
    norm : float
        ``sum_g ||A_g||_2`` over the row groups ``g``.
    """
    if not A.size:
        return 0.0
    work = A.copy() if copy else A
    return np.sqrt(groups_norm2(work, n_orient)).sum()
|
|
|
|
|
|
def _primal_l21(M, G, X, active_set, alpha, n_orient):
    """Primal objective for the mixed-norm inverse problem.

    See :footcite:`GramfortEtAl2012`.

    Parameters
    ----------
    M : array, shape (n_sensors, n_times)
        The data.
    G : array, shape (n_sensors, n_active)
        The gain matrix a.k.a. lead field.
    X : array, shape (n_active, n_times)
        Sources.
    active_set : array of bool, shape (n_sources,)
        Mask of active sources.
    alpha : float
        The regularization parameter.
    n_orient : int
        Number of dipoles per locations (typically 1 or 3).

    Returns
    -------
    p_obj : float
        Primal objective.
    R : array, shape (n_sensors, n_times)
        Current residual (M - G * X).
    nR2 : float
        Data-fitting term.
    GX : array, shape (n_sensors, n_times)
        Forward prediction.
    """
    GX = G[:, active_set] @ X
    R = M - GX
    nR2 = sum_squared(R)
    # copy=True keeps X intact (norm_l21 squares its input otherwise)
    penalty = norm_l21(X, n_orient, copy=True)
    p_obj = 0.5 * nR2 + alpha * penalty
    return p_obj, R, nR2, GX
|
|
|
|
|
|
def dgap_l21(M, G, X, active_set, alpha, n_orient):
    """Duality gap for the mixed norm inverse problem.

    See :footcite:`GramfortEtAl2012`.

    Parameters
    ----------
    M : array, shape (n_sensors, n_times)
        The data.
    G : array, shape (n_sensors, n_active)
        The gain matrix a.k.a. lead field.
    X : array, shape (n_active, n_times)
        Sources.
    active_set : array of bool, shape (n_sources, )
        Mask of active sources.
    alpha : float
        The regularization parameter.
    n_orient : int
        Number of dipoles per locations (typically 1 or 3).

    Returns
    -------
    gap : float
        Dual gap.
    p_obj : float
        Primal objective.
    d_obj : float
        Dual objective. gap = p_obj - d_obj.
    R : array, shape (n_sensors, n_times)
        Current residual (M - G * X).

    References
    ----------
    .. footbibliography::
    """
    p_obj, R, nR2, GX = _primal_l21(M, G, X, active_set, alpha, n_orient)
    # Dual norm of the gradient; rescale R by alpha / dual_norm (capped at 1)
    # to obtain a feasible dual point.
    # NOTE(review): assumes dual_norm != 0, i.e. a non-zero residual
    # correlation with G — confirm callers never hit dual_norm == 0.
    dual_norm = norm_l2inf(np.dot(G.T, R), n_orient, copy=False)
    scaling = alpha / dual_norm
    scaling = min(scaling, 1.0)
    d_obj = (scaling - 0.5 * (scaling**2)) * nR2 + scaling * np.sum(R * GX)

    gap = p_obj - d_obj
    return gap, p_obj, d_obj, R
|
|
|
|
|
|
def _mixed_norm_solver_cd(
    M,
    G,
    alpha,
    lipschitz_constant,
    maxit=10000,
    tol=1e-8,
    init=None,
    n_orient=1,
    dgap_freq=10,
):
    """Solve L21 inverse problem with coordinate descent.

    Thin wrapper around :class:`sklearn.linear_model.MultiTaskLasso`.

    Parameters
    ----------
    M : array, shape (n_sensors, n_times)
        The data.
    G : array, shape (n_sensors, n_sources)
        The gain matrix a.k.a. lead field.
    alpha : float
        Regularization parameter (MNE scale; rescaled below to sklearn's
        per-sample convention).
    lipschitz_constant : None
        Unused; kept for signature compatibility with
        ``_mixed_norm_solver_bcd``.
    maxit : int
        Maximum number of sklearn iterations.
    tol : float
        Tolerance on the duality gap (rescaled to sklearn's convention).
    init : array, shape (n_sources, n_times) | None
        Warm-start coefficients. If None, start from zero.
    n_orient : int
        Number of dipoles per location (typically 1 or 3).
    dgap_freq : int
        Unused; kept for signature compatibility.

    Returns
    -------
    X : array, shape (n_active, n_times)
        Source estimates restricted to the active set.
    active_set : array of bool, shape (n_sources,)
        Mask of active (non-zero) sources.
    p_obj : float
        Primal objective at the solution.
    """
    from sklearn.linear_model import MultiTaskLasso

    assert M.ndim == G.ndim and M.shape[0] == G.shape[0]

    # sklearn's objective divides the data term by n_samples, hence
    # alpha / len(M); tol is likewise normalized by ||M||_F^2.
    clf = MultiTaskLasso(
        alpha=alpha / len(M),
        tol=tol / sum_squared(M),
        fit_intercept=False,
        max_iter=maxit,
        warm_start=True,
    )
    if init is not None:
        clf.coef_ = init.T
    else:
        clf.coef_ = np.zeros((G.shape[1], M.shape[1])).T
    clf.fit(G, M)

    X = clf.coef_.T
    active_set = np.any(X, axis=1)
    X = X[active_set]
    gap, p_obj, d_obj, _ = dgap_l21(M, G, X, active_set, alpha, n_orient)
    return X, active_set, p_obj
|
|
|
|
|
|
def _mixed_norm_solver_bcd(
    M,
    G,
    alpha,
    lipschitz_constant,
    maxit=200,
    tol=1e-8,
    init=None,
    n_orient=1,
    dgap_freq=10,
    use_accel=True,
    K=5,
):
    """Solve L21 inverse problem with block coordinate descent.

    Parameters
    ----------
    M : array, shape (n_sensors, n_times)
        The data.
    G : array, shape (n_sensors, n_sources)
        The gain matrix a.k.a. lead field.
    alpha : float
        The regularization parameter.
    lipschitz_constant : array, shape (n_positions,)
        Per-position Lipschitz constants (computed in ``mixed_norm_solver``).
    maxit : int
        Maximum number of BCD passes.
    tol : float
        Tolerance on the duality gap for convergence checking.
    init : array, shape (n_sources, n_times) | None
        Warm-start coefficients. If None, start from zero.
    n_orient : int
        Number of dipoles per location (typically 1 or 3).
    dgap_freq : int
        The duality gap is computed every ``dgap_freq`` passes.
    use_accel : bool
        If True, use Anderson acceleration on the primal iterates.
    K : int
        Number of previous iterates used for Anderson acceleration.

    Returns
    -------
    X : array, shape (n_active, n_times)
        Solution restricted to the active set.
    active_set : array of bool, shape (n_sources,)
        Mask of active sources.
    E : list of float
        Primal objective at each duality-gap evaluation.
    """
    _, n_times = M.shape
    _, n_sources = G.shape
    n_positions = n_sources // n_orient

    if init is None:
        X = np.zeros((n_sources, n_times))
        R = M.copy()
    else:
        X = init
        R = M - np.dot(G, X)

    E = []  # track primal objective function
    highest_d_obj = -np.inf
    active_set = np.zeros(n_sources, dtype=bool)  # all-False; _bcd updates in place

    alpha_lc = alpha / lipschitz_constant

    if use_accel:
        # ring buffer of the last K+1 primal iterates and their K differences
        last_K_X = np.empty((K + 1, n_sources, n_times))
        U = np.zeros((K, n_sources * n_times))

    # First make G fortran for faster access to blocks of columns
    G = np.asfortranarray(G)
    # Ensure these are correct for dgemm
    assert R.dtype == np.float64
    assert G.dtype == np.float64
    one_ovr_lc = 1.0 / lipschitz_constant

    # assert that all the multiplied matrices are fortran contiguous
    assert X.T.flags.f_contiguous
    assert R.T.flags.f_contiguous
    assert G.flags.f_contiguous
    # storing list of contiguous arrays
    list_G_j_c = []
    for j in range(n_positions):
        idx = slice(j * n_orient, (j + 1) * n_orient)
        list_G_j_c.append(np.ascontiguousarray(G[:, idx]))

    for i in range(maxit):
        # one full BCD pass; X, R and active_set are updated in place
        _bcd(G, X, R, active_set, one_ovr_lc, n_orient, alpha_lc, list_G_j_c)

        if (i + 1) % dgap_freq == 0:
            _, p_obj, d_obj, _ = dgap_l21(
                M, G, X[active_set], active_set, alpha, n_orient
            )
            highest_d_obj = max(d_obj, highest_d_obj)
            gap = p_obj - highest_d_obj
            E.append(p_obj)
            logger.debug(
                "Iteration %d :: p_obj %f :: dgap %f :: n_active %d"
                % (i + 1, p_obj, gap, np.sum(active_set) / n_orient)
            )

            if gap < tol:
                logger.debug(f"Convergence reached ! (gap: {gap} < {tol})")
                break

        # using Anderson acceleration of the primal variable for faster
        # convergence
        if use_accel:
            last_K_X[i % (K + 1)] = X

            if i % (K + 1) == K:
                for k in range(K):
                    U[k] = last_K_X[k + 1].ravel() - last_K_X[k].ravel()
                C = U @ U.T
                # at least on ARM64 we can't rely on np.linalg.solve to
                # reliably raise LinAlgError here, so use SVD instead
                # equivalent to:
                # z = np.linalg.solve(C, np.ones(K))
                u, s, _ = np.linalg.svd(C, hermitian=True)
                if s[-1] <= 1e-6 * s[0] or not np.isfinite(s).all():
                    # C is (numerically) singular: skip the extrapolation
                    logger.debug("Iteration %d: LinAlg Error" % (i + 1))
                    continue
                # z = C^{-1} @ ones via the SVD; c are the normalized
                # extrapolation weights (sum to 1)
                z = ((u * 1 / s) @ u.T).sum(0)
                c = z / z.sum()
                X_acc = np.sum(last_K_X[:-1] * c[:, None, None], axis=0)
                _grp_norm2_acc = groups_norm2(X_acc, n_orient)
                active_set_acc = _grp_norm2_acc != 0
                if n_orient > 1:
                    active_set_acc = np.kron(
                        active_set_acc, np.ones(n_orient, dtype=bool)
                    )
                # only keep the extrapolated point if it improves the primal
                p_obj = _primal_l21(M, G, X[active_set], active_set, alpha, n_orient)[0]
                p_obj_acc = _primal_l21(
                    M, G, X_acc[active_set_acc], active_set_acc, alpha, n_orient
                )[0]
                if p_obj_acc < p_obj:
                    X = X_acc
                    active_set = active_set_acc
                    R = M - G[:, active_set] @ X[active_set]

    X = X[active_set]

    return X, active_set, E
|
|
|
|
|
|
def _bcd(G, X, R, active_set, one_ovr_lc, n_orient, alpha_lc, list_G_j_c):
    """Implement one full pass of BCD.

    BCD stands for Block Coordinate Descent.
    This function make use of scipy.linalg.get_blas_funcs to speed reasons.

    Parameters
    ----------
    G : array, shape (n_sensors, n_sources)
        The gain matrix a.k.a. lead field (Fortran-ordered).
    X : array, shape (n_sources, n_times)
        Sources, modified in place.
    R : array, shape (n_sensors, n_times)
        The residuals: R = M - G @ X, modified in place.
    active_set : array of bool, shape (n_sources, )
        Mask of active sources, modified in place.
    one_ovr_lc : array, shape (n_positions, )
        One over the lipschitz constants.
    n_orient : int
        Number of dipoles per positions (typically 1 or 3).
    alpha_lc : array, shape (n_positions, )
        alpha / (Lipschitz constants).
    list_G_j_c : list of array, shape (n_sensors, n_orient)
        C-contiguous copies of each position's block of columns of G.
    """
    X_j_new = np.zeros_like(X[:n_orient, :], order="C")
    dgemm = _get_dgemm()

    for j, G_j_c in enumerate(list_G_j_c):
        idx = slice(j * n_orient, (j + 1) * n_orient)
        G_j = G[:, idx]
        X_j = X[idx]
        dgemm(
            alpha=one_ovr_lc[j], beta=0.0, a=R.T, b=G_j, c=X_j_new.T, overwrite_c=True
        )
        # X_j_new = one_ovr_lc[j] * (G_j.T @ R)
        # Mathurin's trick to avoid checking all the entries
        was_non_zero = X_j[0, 0] != 0
        # was_non_zero = np.any(X_j)
        if was_non_zero:
            dgemm(alpha=1.0, beta=1.0, a=X_j.T, b=G_j_c.T, c=R.T, overwrite_c=True)
            # R += np.dot(G_j, X_j)
            X_j_new += X_j
        block_norm = sqrt(sum_squared(X_j_new))
        if block_norm <= alpha_lc[j]:
            # proximal step zeroes out this block
            X_j.fill(0.0)
            active_set[idx] = False
        else:
            # group soft-thresholding (shrinkage) of the block
            shrink = max(1.0 - alpha_lc[j] / block_norm, 0.0)
            X_j_new *= shrink
            dgemm(alpha=-1.0, beta=1.0, a=X_j_new.T, b=G_j_c.T, c=R.T, overwrite_c=True)
            # R -= np.dot(G_j, X_j_new)
            X_j[:] = X_j_new
            active_set[idx] = True
|
|
|
|
|
|
@verbose
def mixed_norm_solver(
    M,
    G,
    alpha,
    maxit=3000,
    tol=1e-8,
    verbose=None,
    active_set_size=50,
    debias=True,
    n_orient=1,
    solver="auto",
    return_gap=False,
    dgap_freq=10,
    active_set_init=None,
    X_init=None,
):
    """Solve L1/L2 mixed-norm inverse problem with active set strategy.

    See references :footcite:`GramfortEtAl2012,StrohmeierEtAl2016,
    BertrandEtAl2020`.

    Parameters
    ----------
    M : array, shape (n_sensors, n_times)
        The data.
    G : array, shape (n_sensors, n_dipoles)
        The gain matrix a.k.a. lead field.
    alpha : float
        The regularization parameter. It should be between 0 and 100.
        A value of 100 will lead to an empty active set (no active source).
    maxit : int
        The number of iterations.
    tol : float
        Tolerance on dual gap for convergence checking.
    %(verbose)s
    active_set_size : int
        Size of active set increase at each iteration.
    debias : bool
        Debias source estimates.
    n_orient : int
        The number of orientation (1 : fixed or 3 : free or loose).
    solver : 'cd' | 'bcd' | 'auto'
        The algorithm to use for the optimization. Block Coordinate Descent
        (BCD) uses Anderson acceleration for faster convergence.
    return_gap : bool
        Return final duality gap.
    dgap_freq : int
        The duality gap is computed every dgap_freq iterations of the solver on
        the active set.
    active_set_init : array, shape (n_dipoles,) or None
        The initial active set (boolean array) used at the first iteration.
        If None, the usual active set strategy is applied.
    X_init : array, shape (n_dipoles, n_times) or None
        The initial weight matrix used for warm starting the solver. If None,
        the weights are initialized at zero.

    Returns
    -------
    X : array, shape (n_active, n_times)
        The source estimates.
    active_set : array, shape (new_active_set_size,)
        The mask of active sources. Note that new_active_set_size is the size
        of the active set after convergence of the solver.
    E : list
        The value of the objective function over the iterations.
    gap : float
        Final duality gap. Returned only if return_gap is True.

    References
    ----------
    .. footbibliography::
    """
    n_dipoles = G.shape[1]
    n_positions = n_dipoles // n_orient
    _, n_times = M.shape
    alpha_max = norm_l2inf(np.dot(G.T, M), n_orient, copy=False)
    logger.info(f"-- ALPHA MAX : {alpha_max}")
    alpha = float(alpha)
    X = np.zeros((n_dipoles, n_times), dtype=G.dtype)

    has_sklearn = True
    try:
        from sklearn.linear_model import MultiTaskLasso  # noqa: F401
    except ImportError:
        has_sklearn = False

    _validate_type(solver, str, "solver")
    _check_option("solver", solver, ("cd", "bcd", "auto"))
    if solver == "auto":
        # CD (sklearn) only handles fixed orientation
        if has_sklearn and (n_orient == 1):
            solver = "cd"
        else:
            solver = "bcd"

    if solver == "cd":
        if n_orient == 1 and not has_sklearn:
            warn(
                "Scikit-learn >= 0.12 cannot be found. Using block coordinate"
                " descent instead of coordinate descent."
            )
            solver = "bcd"
        if n_orient > 1:
            warn(
                "Coordinate descent is only available for fixed orientation. "
                "Using block coordinate descent instead of coordinate "
                "descent"
            )
            solver = "bcd"

    if solver == "cd":
        logger.info("Using coordinate descent")
        l21_solver = _mixed_norm_solver_cd
        lc = None
    else:
        assert solver == "bcd"
        logger.info("Using block coordinate descent")
        l21_solver = _mixed_norm_solver_bcd
        G = np.asfortranarray(G)
        # Lipschitz constants: squared column norms (fixed orientation) or
        # spectral norm of each position's n_orient x n_orient Gram block
        if n_orient == 1:
            lc = np.sum(G * G, axis=0)
        else:
            lc = np.empty(n_positions)
            for j in range(n_positions):
                G_tmp = G[:, (j * n_orient) : ((j + 1) * n_orient)]
                lc[j] = np.linalg.norm(np.dot(G_tmp.T, G_tmp), ord=2)

    if active_set_size is not None:
        E = list()
        highest_d_obj = -np.inf
        if X_init is not None and X_init.shape != (n_dipoles, n_times):
            raise ValueError("Wrong dim for initialized coefficients.")
        active_set = (
            active_set_init
            if active_set_init is not None
            else np.zeros(n_dipoles, dtype=bool)
        )
        # seed the active set with the sources most correlated with the data
        idx_large_corr = np.argsort(groups_norm2(np.dot(G.T, M), n_orient))
        new_active_idx = idx_large_corr[-active_set_size:]
        if n_orient > 1:
            # expand position indices to their n_orient dipole indices
            new_active_idx = (
                n_orient * new_active_idx[:, None] + np.arange(n_orient)[None, :]
            ).ravel()
        active_set[new_active_idx] = True
        as_size = np.sum(active_set)
        gap = np.inf
        for k in range(maxit):
            if solver == "bcd":
                lc_tmp = lc[active_set[::n_orient]]
            elif solver == "cd":
                lc_tmp = None
            else:
                # NOTE(review): unreachable — solver is 'cd' or 'bcd' here
                lc_tmp = 1.01 * np.linalg.norm(G[:, active_set], ord=2) ** 2
            X, as_, _ = l21_solver(
                M,
                G[:, active_set],
                alpha,
                lc_tmp,
                maxit=maxit,
                tol=tol,
                init=X_init,
                n_orient=n_orient,
                dgap_freq=dgap_freq,
            )
            active_set[active_set] = as_.copy()
            idx_old_active_set = np.where(active_set)[0]

            _, p_obj, d_obj, R = dgap_l21(M, G, X, active_set, alpha, n_orient)
            highest_d_obj = max(d_obj, highest_d_obj)
            gap = p_obj - highest_d_obj
            E.append(p_obj)
            logger.info(
                "Iteration %d :: p_obj %f :: dgap %f :: "
                "n_active_start %d :: n_active_end %d"
                % (
                    k + 1,
                    p_obj,
                    gap,
                    as_size // n_orient,
                    np.sum(active_set) // n_orient,
                )
            )
            if gap < tol:
                logger.info(f"Convergence reached ! (gap: {gap} < {tol})")
                break

            # add sources if not last iteration
            if k < (maxit - 1):
                # grow the active set with the sources most correlated with
                # the current residual
                idx_large_corr = np.argsort(groups_norm2(np.dot(G.T, R), n_orient))
                new_active_idx = idx_large_corr[-active_set_size:]
                if n_orient > 1:
                    new_active_idx = (
                        n_orient * new_active_idx[:, None]
                        + np.arange(n_orient)[None, :]
                    )
                    new_active_idx = new_active_idx.ravel()
                active_set[new_active_idx] = True
                idx_active_set = np.where(active_set)[0]
                as_size = np.sum(active_set)
                # warm start the next iteration: map the current solution
                # into the enlarged active set
                X_init = np.zeros((as_size, n_times), dtype=X.dtype)
                idx = np.searchsorted(idx_active_set, idx_old_active_set)
                X_init[idx] = X
        else:
            warn(f"Did NOT converge ! (gap: {gap} > {tol})")
    else:
        X, active_set, E = l21_solver(
            M, G, alpha, lc, maxit=maxit, tol=tol, n_orient=n_orient, init=None
        )
        if return_gap:
            gap = dgap_l21(M, G, X, active_set, alpha, n_orient)[0]

    if np.any(active_set) and debias:
        bias = compute_bias(M, G[:, active_set], X, n_orient=n_orient)
        X *= bias[:, np.newaxis]

    logger.info("Final active set size: %s" % (np.sum(active_set) // n_orient))

    if return_gap:
        return X, active_set, E, gap
    else:
        return X, active_set, E
|
|
|
|
|
|
@verbose
def iterative_mixed_norm_solver(
    M,
    G,
    alpha,
    n_mxne_iter,
    maxit=3000,
    tol=1e-8,
    verbose=None,
    active_set_size=50,
    debias=True,
    n_orient=1,
    dgap_freq=10,
    solver="auto",
    weight_init=None,
):
    """Solve L0.5/L2 mixed-norm inverse problem with active set strategy.

    See reference :footcite:`StrohmeierEtAl2016`.

    Parameters
    ----------
    M : array, shape (n_sensors, n_times)
        The data.
    G : array, shape (n_sensors, n_dipoles)
        The gain matrix a.k.a. lead field.
    alpha : float
        The regularization parameter. It should be between 0 and 100.
        A value of 100 will lead to an empty active set (no active source).
    n_mxne_iter : int
        The number of MxNE iterations. If > 1, iterative reweighting
        is applied.
    maxit : int
        The number of iterations.
    tol : float
        Tolerance on dual gap for convergence checking.
    %(verbose)s
    active_set_size : int
        Size of active set increase at each iteration.
    debias : bool
        Debias source estimates.
    n_orient : int
        The number of orientation (1 : fixed or 3 : free or loose).
    dgap_freq : int or np.inf
        The duality gap is evaluated every dgap_freq iterations.
    solver : 'cd' | 'bcd' | 'auto'
        The algorithm to use for the optimization.
    weight_init : array, shape (n_dipoles,) or None
        The initial weight used for reweighting the gain matrix. If None, the
        weights are initialized with ones.

    Returns
    -------
    X : array, shape (n_active, n_times)
        The source estimates.
    active_set : array
        The mask of active sources.
    E : list
        The value of the objective function over the iterations.

    References
    ----------
    .. footbibliography::
    """

    def g(w):
        # square root of the group L2 norms (the L0.5 penalty building block)
        return np.sqrt(np.sqrt(groups_norm2(w.copy(), n_orient)))

    def gprime(w):
        # derivative-based reweighting factor, repeated per orientation
        return 2.0 * np.repeat(g(w), n_orient).ravel()

    E = list()

    if weight_init is not None and weight_init.shape != (G.shape[1],):
        raise ValueError(
            f"Wrong dimension for weight initialization. Got {weight_init.shape}. "
            f"Expected {(G.shape[1],)}."
        )

    weights = weight_init if weight_init is not None else np.ones(G.shape[1])
    active_set = weights != 0
    weights = weights[active_set]
    X = np.zeros((G.shape[1], M.shape[1]))

    for k in range(n_mxne_iter):
        X0 = X.copy()
        active_set_0 = active_set.copy()
        # reweight the gain matrix restricted to the current support
        G_tmp = G[:, active_set] * weights[np.newaxis, :]

        # Use the active set strategy only when the current support is larger
        # than the requested active set size; otherwise solve on the full
        # (restricted) problem directly.
        if active_set_size is not None and np.sum(active_set) > (
            active_set_size * n_orient
        ):
            as_size = active_set_size
        else:
            as_size = None
        X, _active_set, _ = mixed_norm_solver(
            M,
            G_tmp,
            alpha,
            debias=False,
            n_orient=n_orient,
            maxit=maxit,
            tol=tol,
            active_set_size=as_size,
            dgap_freq=dgap_freq,
            solver=solver,
        )

        logger.info("active set size %d" % (_active_set.sum() / n_orient))

        if _active_set.sum() > 0:
            active_set[active_set] = _active_set

            # Reapply weights to have correct unit
            X *= weights[_active_set][:, np.newaxis]
            weights = gprime(X)
            p_obj = 0.5 * np.linalg.norm(
                M - np.dot(G[:, active_set], X), "fro"
            ) ** 2.0 + alpha * np.sum(g(X))
            E.append(p_obj)

            # Check convergence
            if (
                (k >= 1)
                and np.all(active_set == active_set_0)
                and np.all(np.abs(X - X0) < tol)
            ):
                # use the module logger rather than print for consistency
                logger.info("Convergence reached after %d reweightings!" % k)
                break
        else:
            # empty active set: the solution is identically zero
            active_set = np.zeros_like(active_set)
            p_obj = 0.5 * np.linalg.norm(M) ** 2.0
            E.append(p_obj)
            break

    if np.any(active_set) and debias:
        bias = compute_bias(M, G[:, active_set], X, n_orient=n_orient)
        X *= bias[:, np.newaxis]

    return X, active_set, E
|
|
|
|
|
|
###############################################################################
|
|
# TF-MxNE
|
|
|
|
|
|
class _Phi:
    """Have phi stft as callable w/o using a lambda that does not pickle.

    Applies one (or several, averaged) STFT dictionaries to a 2D array of
    time series by precomputing each dictionary as a dense matrix.
    """

    def __init__(self, wsize, tstep, n_coefs, n_times):
        # wsize/tstep/n_coefs may be scalars or sequences (one per dictionary)
        self.wsize = np.atleast_1d(wsize)
        self.tstep = np.atleast_1d(tstep)
        self.n_coefs = np.atleast_1d(n_coefs)
        self.n_dicts = len(tstep)
        # NOTE(review): computed from the raw (pre-atleast_1d) wsize —
        # presumably callers always pass array-like wsize here; confirm.
        self.n_freqs = wsize // 2 + 1
        self.n_steps = self.n_coefs // self.n_freqs
        self.n_times = n_times
        # ravel freq+time here
        self.ops = list()
        for ws, ts in zip(self.wsize, self.tstep):
            # dense (n_times, n_coefs) matrix representation of the STFT
            self.ops.append(
                stft(np.eye(n_times), ws, ts, verbose=False).reshape(n_times, -1)
            )

    def __call__(self, x):  # noqa: D105
        if self.n_dicts == 1:
            return x @ self.ops[0]
        else:
            # average of the dictionaries, normalized by sqrt(n_dicts)
            return np.hstack([x @ op for op in self.ops]) / np.sqrt(self.n_dicts)

    def norm(self, z, ord=2):  # noqa: A002
        """Squared L2 norm if ord == 2 and L1 norm if order == 1."""
        if ord not in (1, 2):
            raise ValueError(f"Only supported norm order are 1 and 2. Got ord = {ord}")
        stft_norm = stft_norm1 if ord == 1 else stft_norm2
        norm = 0.0
        # split the coefficients per dictionary before accumulating norms
        if len(self.n_coefs) > 1:
            z_ = np.array_split(np.atleast_2d(z), np.cumsum(self.n_coefs)[:-1], axis=1)
        else:
            z_ = [np.atleast_2d(z)]
        for i in range(len(z_)):
            norm += stft_norm(z_[i].reshape(-1, self.n_freqs[i], self.n_steps[i]))
        return norm
|
|
|
|
|
|
class _PhiT:
    """Have phi.T istft as callable w/o using a lambda that does not pickle.

    Applies the adjoint (inverse STFT) of ``_Phi`` via precomputed dense
    matrices for the real and imaginary parts of the coefficients.
    """

    def __init__(self, tstep, n_freqs, n_steps, n_times):
        self.tstep = tstep
        self.n_freqs = n_freqs
        self.n_steps = n_steps
        self.n_times = n_times
        self.n_dicts = len(tstep) if isinstance(tstep, np.ndarray) else 1
        self.n_coefs = list()
        self.op_re = list()
        self.op_im = list()
        for nf, ns, ts in zip(self.n_freqs, self.n_steps, self.tstep):
            nc = nf * ns
            self.n_coefs.append(nc)
            # dense (n_coefs, n_times) matrices mapping real/imag coefficient
            # parts back to the time domain
            eye = np.eye(nc).reshape(nf, ns, nf, ns)
            self.op_re.append(istft(eye, ts, n_times).reshape(nc, n_times))
            self.op_im.append(istft(eye * 1j, ts, n_times).reshape(nc, n_times))

    def __call__(self, z):  # noqa: D105
        if self.n_dicts == 1:
            return z.real @ self.op_re[0] + z.imag @ self.op_im[0]
        else:
            # sum the contributions of each dictionary, same sqrt(n_dicts)
            # normalization as _Phi.__call__
            x_out = np.zeros((z.shape[0], self.n_times))
            z_ = np.array_split(z, np.cumsum(self.n_coefs)[:-1], axis=1)
            for this_z, op_re, op_im in zip(z_, self.op_re, self.op_im):
                x_out += this_z.real @ op_re + this_z.imag @ op_im
            return x_out / np.sqrt(self.n_dicts)
|
|
|
|
|
|
def norm_l21_tf(Z, phi, n_orient, w_space=None):
    """Mixed L21 norm of TF coefficients, optionally spatially weighted.

    Parameters
    ----------
    Z : array, shape (n_sources, n_coefs)
        TF coefficients.
    phi : instance of _Phi
        The TF operator (its ``norm`` method is used with ``ord=2``).
    n_orient : int
        Number of dipoles per location.
    w_space : array, shape (n_positions,) | None
        Optional per-position weights.

    Returns
    -------
    l21_norm : float
        Sum over positions of the (weighted) L2 norms.
    """
    if not Z.shape[0]:
        return 0.0
    per_position = np.sqrt(phi.norm(Z, ord=2).reshape(-1, n_orient).sum(axis=1))
    if w_space is not None:
        per_position = per_position * w_space
    return per_position.sum()
|
|
|
|
|
|
def norm_l1_tf(Z, phi, n_orient, w_time):
    """Weighted L1 norm of TF coefficients, combined across orientations.

    Parameters
    ----------
    Z : array, shape (n_sources, n_coefs)
        TF coefficients.
    phi : instance of _Phi
        The TF operator (its ``norm`` method is used with ``ord=1``).
    n_orient : int
        Number of dipoles per location.
    w_time : array, shape (n_positions, n_coefs) | None
        Optional per-coefficient weights.

    Returns
    -------
    l1_norm : float
        Sum of the (weighted) per-position coefficient magnitudes.
    """
    if not Z.shape[0]:
        return 0.0
    n_positions = Z.shape[0] // n_orient
    # L2-combine the n_orient rows belonging to each source position
    orient_norm = np.sqrt(
        ((np.abs(Z) ** 2.0).reshape((n_orient, -1), order="F")).sum(axis=0)
    )
    orient_norm = orient_norm.reshape((n_positions, -1), order="F")
    if w_time is not None:
        orient_norm *= w_time
    return phi.norm(orient_norm, ord=1).sum()
|
|
|
|
|
|
def norm_epsilon(Y, l1_ratio, phi, w_space=1.0, w_time=None):
    """Weighted epsilon norm.

    The weighted epsilon norm is the dual norm of::

    w_{space} * (1. - l1_ratio) * ||Y||_2 + l1_ratio * ||Y||_{1, w_{time}}.

    where `||Y||_{1, w_{time}} = (np.abs(Y) * w_time).sum()`

    Warning: it takes into account the fact that Y only contains coefficients
    corresponding to the positive frequencies (see `stft_norm2()`): some
    entries will be counted twice. It is also assumed that all entries of both
    Y and w_time are non-negative. See
    :footcite:`NdiayeEtAl2016,BurdakovMerkulov2001`.

    Parameters
    ----------
    Y : array, shape (n_coefs,)
        The input data.
    l1_ratio : float between 0 and 1
        Tradeoff between L2 and L1 regularization. When it is 0, no temporal
        regularization is applied.
    phi : instance of _Phi
        The TF operator.
    w_space : float
        Scalar weight of the L2 norm. By default, it is taken equal to 1.
    w_time : array, shape (n_coefs, ) | None
        Weights of each TF coefficient in the L1 norm. If None, weights equal
        to 1 are used.


    Returns
    -------
    nu : float
        The value of the dual norm evaluated at Y.

    References
    ----------
    .. footbibliography::
    """
    # since the solution is invariant to flipped signs in Y, all entries
    # of Y are assumed positive

    # Add negative freqs: count all freqs twice except first and last:
    # (the layout groups n_steps entries per frequency — see the
    # reshape(-1, n_freqs, n_steps) in _Phi.norm — so the first and last
    # n_steps entries of each dictionary are the first/last frequency bins)
    freqs_count = np.full(len(Y), 2)
    for i, fc in enumerate(np.array_split(freqs_count, np.cumsum(phi.n_coefs)[:-1])):
        fc[: phi.n_steps[i]] = 1
        fc[-phi.n_steps[i] :] = 1

    # exclude 0 weights:
    if w_time is not None:
        nonzero_weights = w_time != 0.0
        Y = Y[nonzero_weights]
        freqs_count = freqs_count[nonzero_weights]
        w_time = w_time[nonzero_weights]

    norm_inf_Y = np.max(Y / w_time) if w_time is not None else np.max(Y)
    if l1_ratio == 1.0:
        # dual norm of L1 weighted is Linf with inverse weights
        return norm_inf_Y
    elif l1_ratio == 0.0:
        # dual norm of L2 is L2
        return np.sqrt(phi.norm(Y[None, :], ord=2).sum())

    if norm_inf_Y == 0.0:
        return 0.0

    # ignore some values of Y by lower bound on dual norm:
    if w_time is None:
        idx = Y > l1_ratio * norm_inf_Y
    else:
        idx = Y > l1_ratio * np.max(
            Y / (w_space * (1.0 - l1_ratio) + l1_ratio * w_time)
        )

    if idx.sum() == 1:
        return norm_inf_Y

    # sort both Y / w_time and freqs_count at the same time
    if w_time is not None:
        idx_sort = np.argsort(Y[idx] / w_time[idx])[::-1]
        w_time = w_time[idx][idx_sort]
    else:
        idx_sort = np.argsort(Y[idx])[::-1]

    Y = Y[idx][idx_sort]
    freqs_count = freqs_count[idx][idx_sort]

    # duplicate the entries counted twice (negative-frequency twins)
    Y = np.repeat(Y, freqs_count)
    if w_time is not None:
        w_time = np.repeat(w_time, freqs_count)

    # partial sums over the sorted entries; `upper` is the bound that
    # determines how many leading entries participate in the solution
    K = Y.shape[0]
    if w_time is None:
        p_sum_Y2 = np.cumsum(Y**2)
        p_sum_w2 = np.arange(1, K + 1)
        p_sum_Yw = np.cumsum(Y)
        upper = p_sum_Y2 / Y**2 - 2.0 * p_sum_Yw / Y + p_sum_w2
    else:
        p_sum_Y2 = np.cumsum(Y**2)
        p_sum_w2 = np.cumsum(w_time**2)
        p_sum_Yw = np.cumsum(Y * w_time)
        upper = p_sum_Y2 / (Y / w_time) ** 2 - 2.0 * p_sum_Yw / (Y / w_time) + p_sum_w2
    upper_greater = np.where(upper > w_space**2 * (1.0 - l1_ratio) ** 2 / l1_ratio**2)[
        0
    ]

    # i0 is the last index at which the bound still holds
    i0 = upper_greater[0] - 1 if upper_greater.size else K - 1

    p_sum_Y2 = p_sum_Y2[i0]
    p_sum_w2 = p_sum_w2[i0]
    p_sum_Yw = p_sum_Yw[i0]

    # solve the resulting quadratic for the dual norm value; the linear
    # fallback guards against a vanishing quadratic coefficient
    denom = l1_ratio**2 * p_sum_w2 - w_space**2 * (1.0 - l1_ratio) ** 2
    if np.abs(denom) < 1e-10:
        return p_sum_Y2 / (2.0 * l1_ratio * p_sum_Yw)
    else:
        delta = (l1_ratio * p_sum_Yw) ** 2 - p_sum_Y2 * denom
        return (l1_ratio * p_sum_Yw - np.sqrt(delta)) / denom
|
|
|
|
|
|
def norm_epsilon_inf(G, R, phi, l1_ratio, n_orient, w_space=None, w_time=None):
    """Weighted epsilon-inf norm of phi(np.dot(G.T, R)).

    Parameters
    ----------
    G : array, shape (n_sensors, n_sources)
        Gain matrix a.k.a. lead field.
    R : array, shape (n_sensors, n_times)
        Residual.
    phi : instance of _Phi
        The TF operator.
    l1_ratio : float between 0 and 1
        Parameter controlling the tradeoff between L21 and L1 regularization.
        0 corresponds to an absence of temporal regularization, ie MxNE.
    n_orient : int
        Number of dipoles per location (typically 1 or 3).
    w_space : array, shape (n_positions,) or None.
        Weights for the L2 term of the epsilon norm. If None, weights are
        all equal to 1.
    w_time : array, shape (n_positions, n_coefs) or None
        Weights for the L1 term of the epsilon norm. If None, weights are
        all equal to 1.

    Returns
    -------
    nu : float
        The maximum value of the epsilon norms over groups of n_orient dipoles
        (consecutive rows of phi(np.dot(G.T, R))).
    """
    n_positions = G.shape[1] // n_orient
    coefs = np.abs(phi(G.T @ R))
    # collapse orientations: L2 norm over each group of n_orient rows
    coefs = np.linalg.norm(coefs.reshape((n_orient, -1), order="F"), axis=0)
    coefs = coefs.reshape((n_positions, -1), order="F")

    nu = 0.0
    for pos in range(n_positions):
        w_t = None if w_time is None else w_time[pos]
        w_s = 1.0 if w_space is None else w_space[pos]
        nu = max(
            nu, norm_epsilon(coefs[pos], l1_ratio, phi, w_space=w_s, w_time=w_t)
        )

    return nu
|
|
|
|
|
|
def dgap_l21l1(
    M,
    G,
    Z,
    active_set,
    alpha_space,
    alpha_time,
    phi,
    phiT,
    n_orient,
    highest_d_obj,
    w_space=None,
    w_time=None,
):
    """Duality gap for the time-frequency mixed norm inverse problem.

    See :footcite:`GramfortEtAl2012,NdiayeEtAl2016`

    Parameters
    ----------
    M : array, shape (n_sensors, n_times)
        The data.
    G : array, shape (n_sensors, n_sources)
        Gain matrix a.k.a. lead field.
    Z : array, shape (n_active, n_coefs)
        Sources in TF domain.
    active_set : array of bool, shape (n_sources, )
        Mask of active sources.
    alpha_space : float
        The spatial regularization parameter.
    alpha_time : float
        The temporal regularization parameter. The higher it is the smoother
        will be the estimated time series.
    phi : instance of _Phi
        The TF operator.
    phiT : instance of _PhiT
        The transpose of the TF operator.
    n_orient : int
        Number of dipoles per locations (typically 1 or 3).
    highest_d_obj : float
        The highest value of the dual objective so far.
    w_space : array, shape (n_positions, )
        Array of spatial weights.
    w_time : array, shape (n_positions, n_coefs)
        Array of TF weights.

    Returns
    -------
    gap : float
        Dual gap
    p_obj : float
        Primal objective
    d_obj : float
        Dual objective. gap = p_obj - d_obj
    R : array, shape (n_sensors, n_times)
        Current residual (M - G * X)

    References
    ----------
    .. footbibliography::
    """
    X = phiT(Z)
    GX = np.dot(G[:, active_set], X)
    R = M - GX

    # some functions need w_time only on active_set, other need it completely
    if w_time is not None:
        w_time_as = w_time[active_set[::n_orient]]
    else:
        w_time_as = None
    if w_space is not None:
        w_space_as = w_space[active_set[::n_orient]]
    else:
        w_space_as = None

    # primal objective: data fit + spatial (L21) + temporal (L1) penalties
    penaltyl1 = norm_l1_tf(Z, phi, n_orient, w_time_as)
    penaltyl21 = norm_l21_tf(Z, phi, n_orient, w_space_as)
    nR2 = sum_squared(R)
    p_obj = 0.5 * nR2 + alpha_space * penaltyl21 + alpha_time * penaltyl1

    # dual objective: rescale R by the epsilon-inf dual norm (capped at 1)
    # to obtain a feasible dual point
    l1_ratio = alpha_time / (alpha_space + alpha_time)
    dual_norm = norm_epsilon_inf(
        G, R, phi, l1_ratio, n_orient, w_space=w_space, w_time=w_time
    )
    scaling = min(1.0, (alpha_space + alpha_time) / dual_norm)

    d_obj = (scaling - 0.5 * (scaling**2)) * nR2 + scaling * np.sum(R * GX)
    # keep the best (largest) dual objective seen so far
    d_obj = max(d_obj, highest_d_obj)

    gap = p_obj - d_obj
    return gap, p_obj, d_obj, R
|
|
|
|
|
|
def _tf_mixed_norm_solver_bcd_(
    M,
    G,
    Z,
    active_set,
    candidates,
    alpha_space,
    alpha_time,
    lipschitz_constant,
    phi,
    phiT,
    *,
    w_space=None,
    w_time=None,
    n_orient=1,
    maxit=200,
    tol=1e-8,
    dgap_freq=10,
    perc=None,
):
    """Run block coordinate descent sweeps for the TF mixed-norm problem.

    For each candidate position, a gradient step on the data-fit term is
    followed by temporal (L1) and spatial (L21) shrinkage of the TF
    coefficients. ``Z``, ``active_set`` and the residual are maintained
    incrementally across sweeps.

    Parameters
    ----------
    M : array, shape (n_sensors, n_times)
        The data.
    G : array, shape (n_sensors, n_sources)
        The gain matrix a.k.a. lead field.
    Z : dict
        TF coefficients, one entry per position (0.0 for inactive ones).
    active_set : array of bool, shape (n_sources,)
        Mask of active sources. Mutated in place.
    candidates : iterable of int
        Position indices visited during each sweep.
    alpha_space : float
        The spatial regularization parameter.
    alpha_time : float
        The temporal regularization parameter.
    lipschitz_constant : array, shape (n_positions,)
        Per-position Lipschitz constants scaling the gradient steps.
    phi : instance of _Phi
        The TF operator.
    phiT : instance of _PhiT
        The transpose of the TF operator.
    w_space : array, shape (n_positions,) | None
        Spatial weights (None means unweighted).
    w_time : array, shape (n_positions, n_coefs) | None
        TF weights (None means unweighted).
    n_orient : int
        Number of dipoles per location (typically 1 or 3).
    maxit : int
        Maximum number of BCD sweeps.
    tol : float
        Tolerance on the duality gap.
    dgap_freq : int
        The duality gap is evaluated every dgap_freq sweeps.
    perc : float | None
        If not None, stop early once the number of active positions drops
        to ``perc * n_positions`` or below.

    Returns
    -------
    Z : dict
        Updated TF coefficients.
    active_set : array of bool, shape (n_sources,)
        Updated mask of active sources.
    E : list of float
        Primal objective at each duality-gap evaluation.
    converged : bool
        Whether the duality gap fell below ``tol``.
    """
    n_sources = G.shape[1]
    n_positions = n_sources // n_orient

    # First make G fortran for faster access to blocks of columns
    Gd = np.asfortranarray(G)
    G = np.ascontiguousarray(Gd.T.reshape(n_positions, n_orient, -1).transpose(0, 2, 1))

    R = M.copy()  # residual
    active = np.where(active_set[::n_orient])[0]
    for idx in active:
        # remove the forward projection of every initially-active source
        R -= np.dot(G[idx], phiT(Z[idx]))

    E = []  # track primal objective function

    # Pre-scale the penalties by the Lipschitz constants so the shrinkage
    # thresholds can be compared directly to the updated coefficient blocks.
    if w_time is None:
        alpha_time_lc = alpha_time / lipschitz_constant
    else:
        alpha_time_lc = alpha_time * w_time / lipschitz_constant[:, None]
    if w_space is None:
        alpha_space_lc = alpha_space / lipschitz_constant
    else:
        alpha_space_lc = alpha_space * w_space / lipschitz_constant

    converged = False
    d_obj = -np.inf

    for i in range(maxit):
        for jj in candidates:
            ids = jj * n_orient
            ide = ids + n_orient

            G_j = G[jj]
            Z_j = Z[jj]
            active_set_j = active_set[ids:ide]  # view: writes update active_set

            was_active = np.any(active_set_j)

            # gradient step
            GTR = np.dot(G_j.T, R) / lipschitz_constant[jj]
            X_j_new = GTR.copy()

            if was_active:
                # add this position's contribution back so R excludes it
                # while its coefficients are being updated
                X_j = phiT(Z_j)
                R += np.dot(G_j, X_j)
                X_j_new += X_j

            rows_norm = np.linalg.norm(X_j_new, "fro")
            if rows_norm <= alpha_space_lc[jj]:
                # entire block killed by the spatial penalty
                if was_active:
                    Z[jj] = 0.0
                    active_set_j[:] = False
            else:
                GTR_phi = phi(GTR)
                if was_active:
                    Z_j_new = Z_j + GTR_phi
                else:
                    Z_j_new = GTR_phi
                col_norm = np.linalg.norm(Z_j_new, axis=0)

                if np.all(col_norm <= alpha_time_lc[jj]):
                    # all TF coefficients killed by the temporal penalty
                    Z[jj] = 0.0
                    active_set_j[:] = False
                else:
                    # l1
                    shrink = np.maximum(
                        1.0
                        - alpha_time_lc[jj] / np.maximum(col_norm, alpha_time_lc[jj]),
                        0.0,
                    )
                    if w_time is not None:
                        # zero-weighted coefficients are forced to zero
                        shrink[w_time[jj] == 0.0] = 0.0
                    Z_j_new *= shrink[np.newaxis, :]

                    # l21
                    shape_init = Z_j_new.shape
                    row_norm = np.sqrt(phi.norm(Z_j_new, ord=2).sum())
                    if row_norm <= alpha_space_lc[jj]:
                        Z[jj] = 0.0
                        active_set_j[:] = False
                    else:
                        shrink = np.maximum(
                            1.0
                            - alpha_space_lc[jj]
                            / np.maximum(row_norm, alpha_space_lc[jj]),
                            0.0,
                        )
                        Z_j_new *= shrink
                        Z[jj] = Z_j_new.reshape(-1, *shape_init[1:]).copy()
                        active_set_j[:] = True
                        # fold the updated contribution back into the residual
                        Z_j_phi_T = phiT(Z[jj])
                        R -= np.dot(G_j, Z_j_phi_T)

        if (i + 1) % dgap_freq == 0:
            # duality-gap evaluation is costly, hence only every dgap_freq sweeps
            Zd = np.vstack([Z[pos] for pos in range(n_positions) if np.any(Z[pos])])
            gap, p_obj, d_obj, _ = dgap_l21l1(
                M,
                Gd,
                Zd,
                active_set,
                alpha_space,
                alpha_time,
                phi,
                phiT,
                n_orient,
                d_obj,
                w_space=w_space,
                w_time=w_time,
            )
            converged = gap < tol
            E.append(p_obj)
            logger.info(
                "\n Iteration %d :: n_active %d"
                % (i + 1, np.sum(active_set) / n_orient)
            )
            logger.info(f" dgap {gap:.2e} :: p_obj {p_obj} :: d_obj {d_obj}")

        if converged:
            break

        if perc is not None:
            # early exit once the active set has shrunk enough
            if np.sum(active_set) / float(n_orient) <= perc * n_positions:
                break

    return Z, active_set, E, converged
|
|
|
|
|
|
def _tf_mixed_norm_solver_bcd_active_set(
    M,
    G,
    alpha_space,
    alpha_time,
    lipschitz_constant,
    phi,
    phiT,
    *,
    Z_init=None,
    w_space=None,
    w_time=None,
    n_orient=1,
    maxit=200,
    tol=1e-8,
    dgap_freq=10,
):
    """Solve the TF mixed-norm problem with BCD and an active-set strategy.

    Alternates between a single BCD pass over all positions (letting new
    sources enter the active set) and full BCD iterations restricted to
    the active positions, until the duality gap evaluated on the full
    problem falls below ``tol``.

    Parameters
    ----------
    M : array, shape (n_sensors, n_times)
        The data.
    G : array, shape (n_sensors, n_sources)
        The gain matrix a.k.a. lead field.
    alpha_space : float
        The spatial regularization parameter.
    alpha_time : float
        The temporal regularization parameter.
    lipschitz_constant : array, shape (n_positions,)
        Per-position Lipschitz constants.
    phi : instance of _Phi
        The TF operator.
    phiT : instance of _PhiT
        The transpose of the TF operator.
    Z_init : array, shape (n_sources, n_coefs) | None
        Initial TF coefficients (None means all-zero start).
    w_space : array, shape (n_positions,) | None
        Spatial weights (None means unweighted).
    w_time : array, shape (n_positions, n_coefs) | None
        TF weights (None means unweighted).
    n_orient : int
        Number of dipoles per location (typically 1 or 3).
    maxit : int
        Maximum number of BCD sweeps for the restricted subproblems.
    tol : float
        Tolerance on the duality gap.
    dgap_freq : int
        The duality gap is evaluated every dgap_freq iterations.

    Returns
    -------
    X : array, shape (n_active, n_times)
        Source estimates in the time domain.
    Z : array, shape (n_active, n_coefs)
        Source estimates in the TF domain.
    active_set : array of bool, shape (n_sources,)
        Mask of active sources.
    E : list of float
        Primal objective values collected over the iterations.
    gap : float
        Final duality gap.

    Raises
    ------
    ValueError
        If ``Z_init`` does not have shape (n_sources, n_coefs).
    """
    n_sensors, n_times = M.shape
    n_sources = G.shape[1]
    n_positions = n_sources // n_orient

    Z = dict.fromkeys(np.arange(n_positions), 0.0)
    active_set = np.zeros(n_sources, dtype=bool)
    active = []
    if Z_init is not None:
        if Z_init.shape != (n_sources, phi.n_coefs.sum()):
            # ValueError (not bare Exception): this is an invalid-argument
            # condition; still caught by callers catching Exception
            raise ValueError(
                "Z_init must be None or an array with shape (n_sources, n_coefs)."
            )
        # seed the active set from the nonzero blocks of Z_init
        for ii in range(n_positions):
            if np.any(Z_init[ii * n_orient : (ii + 1) * n_orient]):
                active_set[ii * n_orient : (ii + 1) * n_orient] = True
                active.append(ii)
        if len(active):
            Z.update(dict(zip(active, np.vsplit(Z_init[active_set], len(active)))))

    E = []
    candidates = range(n_positions)
    d_obj = -np.inf

    while True:
        # single BCD pass on all positions:
        Z_init = dict.fromkeys(np.arange(n_positions), 0.0)
        Z_init.update(dict(zip(active, Z.values())))
        Z, active_set, E_tmp, _ = _tf_mixed_norm_solver_bcd_(
            M,
            G,
            Z_init,
            active_set,
            candidates,
            alpha_space,
            alpha_time,
            lipschitz_constant,
            phi,
            phiT,
            w_space=w_space,
            w_time=w_time,
            n_orient=n_orient,
            maxit=1,
            tol=tol,
            perc=None,
        )

        E += E_tmp

        # multiple BCD pass on active positions:
        active = np.where(active_set[::n_orient])[0]
        Z_init = dict(zip(range(len(active)), [Z[idx] for idx in active]))
        candidates_ = range(len(active))
        # restrict the weights to the active positions for the subproblem
        if w_space is not None:
            w_space_as = w_space[active_set[::n_orient]]
        else:
            w_space_as = None
        if w_time is not None:
            w_time_as = w_time[active_set[::n_orient]]
        else:
            w_time_as = None

        Z, as_, E_tmp, converged = _tf_mixed_norm_solver_bcd_(
            M,
            G[:, active_set],
            Z_init,
            np.ones(len(active) * n_orient, dtype=bool),
            candidates_,
            alpha_space,
            alpha_time,
            lipschitz_constant[active_set[::n_orient]],
            phi,
            phiT,
            w_space=w_space_as,
            w_time=w_time_as,
            n_orient=n_orient,
            maxit=maxit,
            tol=tol,
            dgap_freq=dgap_freq,
            perc=0.5,
        )
        active = np.where(active_set[::n_orient])[0]
        # map the subproblem active set back into full-source indexing
        active_set[active_set] = as_.copy()
        E += E_tmp

        # regardless of subproblem convergence, decide on the duality gap
        # computed for the full problem below
        converged = True
        if converged:
            Zd = np.vstack([Z[pos] for pos in range(len(Z)) if np.any(Z[pos])])
            gap, p_obj, d_obj, _ = dgap_l21l1(
                M,
                G,
                Zd,
                active_set,
                alpha_space,
                alpha_time,
                phi,
                phiT,
                n_orient,
                d_obj,
                w_space,
                w_time,
            )
            logger.info(
                "\ndgap %.2e :: p_obj %f :: d_obj %f :: n_active %d"
                % (gap, p_obj, d_obj, np.sum(active_set) / n_orient)
            )
            if gap < tol:
                logger.info("\nConvergence reached!\n")
                break

    if active_set.sum():
        Z = np.vstack([Z[pos] for pos in range(len(Z)) if np.any(Z[pos])])
        X = phiT(Z)
    else:
        # no active source left: return empty estimates
        Z = np.zeros((0, phi.n_coefs.sum()), dtype=np.complex128)
        X = np.zeros((0, n_times))

    return X, Z, active_set, E, gap
|
|
|
|
|
|
@verbose
def tf_mixed_norm_solver(
    M,
    G,
    alpha_space,
    alpha_time,
    wsize=64,
    tstep=4,
    n_orient=1,
    maxit=200,
    tol=1e-8,
    active_set_size=None,
    debias=True,
    return_gap=False,
    dgap_freq=10,
    verbose=None,
):
    """Solve TF L21+L1 inverse solver with BCD and active set approach.

    See :footcite:`GramfortEtAl2013b,GramfortEtAl2011,BekhtiEtAl2016`.

    Parameters
    ----------
    M : array, shape (n_sensors, n_times)
        The data.
    G : array, shape (n_sensors, n_dipoles)
        The gain matrix a.k.a. lead field.
    alpha_space : float
        The spatial regularization parameter.
    alpha_time : float
        The temporal regularization parameter. The higher it is the smoother
        will be the estimated time series.
    wsize: int or array-like
        Length of the STFT window in samples (must be a multiple of 4).
        If an array is passed, multiple TF dictionaries are used (each having
        its own wsize and tstep) and each entry of wsize must be a multiple
        of 4.
    tstep: int or array-like
        Step between successive windows in samples (must be a multiple of 2,
        a divider of wsize and smaller than wsize/2) (default: wsize/2).
        If an array is passed, multiple TF dictionaries are used (each having
        its own wsize and tstep), and each entry of tstep must be a multiple
        of 2 and divide the corresponding entry of wsize.
    n_orient : int
        The number of orientation (1 : fixed or 3 : free or loose).
    maxit : int
        The number of iterations.
    tol : float
        If absolute difference between estimates at 2 successive iterations
        is lower than tol, the convergence is reached.
    active_set_size : int | None
        Unused by this solver (never read in the function body); accepted
        for signature compatibility.
    debias : bool
        Debias source estimates.
    return_gap : bool
        Return final duality gap.
    dgap_freq : int or np.inf
        The duality gap is evaluated every dgap_freq iterations.
    %(verbose)s

    Returns
    -------
    X : array, shape (n_active, n_times)
        The source estimates.
    active_set : array
        The mask of active sources.
    E : list
        The value of the objective function every dgap_freq iteration. If
        dgap_freq is np.inf, it will be empty.
    gap : float
        Final duality gap. Returned only if return_gap is True.

    References
    ----------
    .. footbibliography::
    """
    n_sensors, n_times = M.shape
    n_sensors, n_sources = G.shape
    n_positions = n_sources // n_orient

    tstep = np.atleast_1d(tstep)
    wsize = np.atleast_1d(wsize)
    if len(tstep) != len(wsize):
        raise ValueError(
            "The same number of window sizes and steps must be "
            f"passed. Got tstep = {tstep} and wsize = {wsize}"
        )

    # build one STFT dictionary per (wsize, tstep) pair
    n_steps = np.ceil(M.shape[1] / tstep.astype(float)).astype(int)
    n_freqs = wsize // 2 + 1
    n_coefs = n_steps * n_freqs
    phi = _Phi(wsize, tstep, n_coefs, n_times)
    phiT = _PhiT(tstep, n_freqs, n_steps, n_times)

    # Lipschitz constants: squared column norms for fixed orientation,
    # spectral norm of each n_orient-column block otherwise
    if n_orient == 1:
        lc = np.sum(G * G, axis=0)
    else:
        lc = np.empty(n_positions)
        for j in range(n_positions):
            G_tmp = G[:, (j * n_orient) : ((j + 1) * n_orient)]
            lc[j] = np.linalg.norm(np.dot(G_tmp.T, G_tmp), ord=2)

    logger.info("Using block coordinate descent with active set approach")
    X, Z, active_set, E, gap = _tf_mixed_norm_solver_bcd_active_set(
        M,
        G,
        alpha_space,
        alpha_time,
        lc,
        phi,
        phiT,
        Z_init=None,
        n_orient=n_orient,
        maxit=maxit,
        tol=tol,
        dgap_freq=dgap_freq,
    )

    if np.any(active_set) and debias:
        bias = compute_bias(M, G[:, active_set], X, n_orient=n_orient)
        X *= bias[:, np.newaxis]

    if return_gap:
        return X, active_set, E, gap
    else:
        return X, active_set, E
|
|
|
|
|
|
@verbose
def iterative_tf_mixed_norm_solver(
    M,
    G,
    alpha_space,
    alpha_time,
    n_tfmxne_iter,
    wsize=64,
    tstep=4,
    maxit=3000,
    tol=1e-8,
    debias=True,
    n_orient=1,
    dgap_freq=10,
    verbose=None,
):
    """Solve TF L0.5/L1 + L0.5 inverse problem with BCD + active set approach.

    Parameters
    ----------
    M : array, shape (n_sensors, n_times)
        The data.
    G : array, shape (n_sensors, n_dipoles)
        The gain matrix a.k.a. lead field.
    alpha_space : float
        The spatial regularization parameter. The higher it is the less there
        will be active sources.
    alpha_time : float
        The temporal regularization parameter. The higher it is the smoother
        will be the estimated time series. 0 means no temporal regularization,
        a.k.a. irMxNE.
    n_tfmxne_iter : int
        Number of TF-MxNE iterations. If > 1, iterative reweighting is applied.
    wsize : int or array-like
        Length of the STFT window in samples (must be a multiple of 4).
        If an array is passed, multiple TF dictionaries are used (each having
        its own wsize and tstep) and each entry of wsize must be a multiple
        of 4.
    tstep : int or array-like
        Step between successive windows in samples (must be a multiple of 2,
        a divider of wsize and smaller than wsize/2) (default: wsize/2).
        If an array is passed, multiple TF dictionaries are used (each having
        its own wsize and tstep), and each entry of tstep must be a multiple
        of 2 and divide the corresponding entry of wsize.
    maxit : int
        The maximum number of iterations for each TF-MxNE problem.
    tol : float
        If absolute difference between estimates at 2 successive iterations
        is lower than tol, the convergence is reached. Also used as criterion
        on duality gap for each TF-MxNE problem.
    debias : bool
        Debias source estimates.
    n_orient : int
        The number of orientation (1 : fixed or 3 : free or loose).
    dgap_freq : int or np.inf
        The duality gap is evaluated every dgap_freq iterations.
    %(verbose)s

    Returns
    -------
    X : array, shape (n_active, n_times)
        The source estimates.
    active_set : array
        The mask of active sources.
    E : list
        The value of the objective function over iterations.
    """
    n_sensors, n_times = M.shape
    n_sources = G.shape[1]
    n_positions = n_sources // n_orient

    tstep = np.atleast_1d(tstep)
    wsize = np.atleast_1d(wsize)
    if len(tstep) != len(wsize):
        raise ValueError(
            "The same number of window sizes and steps must be "
            f"passed. Got tstep = {tstep} and wsize = {wsize}"
        )

    # build one STFT dictionary per (wsize, tstep) pair
    n_steps = np.ceil(n_times / tstep.astype(float)).astype(int)
    n_freqs = wsize // 2 + 1
    n_coefs = n_steps * n_freqs
    phi = _Phi(wsize, tstep, n_coefs, n_times)
    phiT = _PhiT(tstep, n_freqs, n_steps, n_times)

    # Lipschitz constants: squared column norms for fixed orientation,
    # spectral norm of each n_orient-column block otherwise
    if n_orient == 1:
        lc = np.sum(G * G, axis=0)
    else:
        lc = np.empty(n_positions)
        for j in range(n_positions):
            G_tmp = G[:, (j * n_orient) : ((j + 1) * n_orient)]
            lc[j] = np.linalg.norm(np.dot(G_tmp.T, G_tmp), ord=2)

    # space and time penalties, and inverse of their derivatives:
    def g_space(Z):
        return np.sqrt(np.sqrt(phi.norm(Z, ord=2).reshape(-1, n_orient).sum(axis=1)))

    def g_space_prime_inv(Z):
        return 2.0 * g_space(Z)

    def g_time(Z):
        return np.sqrt(
            np.sqrt(
                np.sum((np.abs(Z) ** 2.0).reshape((n_orient, -1), order="F"), axis=0)
            ).reshape((-1, Z.shape[1]), order="F")
        )

    def g_time_prime_inv(Z):
        return 2.0 * g_time(Z)

    E = list()

    active_set = np.ones(n_sources, dtype=bool)
    Z = np.zeros((n_sources, phi.n_coefs.sum()), dtype=np.complex128)

    for k in range(n_tfmxne_iter):
        active_set_0 = active_set.copy()
        Z0 = Z.copy()

        if k == 0:
            # first pass is unweighted
            w_space = None
            w_time = None
        else:
            # reweighting from the previous estimate; coefficients whose
            # weight derivative is zero get w_time forced to 0 (via the
            # -1 sentinel) so they stay excluded
            w_space = 1.0 / g_space_prime_inv(Z)
            w_time = g_time_prime_inv(Z)
            w_time[w_time == 0.0] = -1.0
            w_time = 1.0 / w_time
            w_time[w_time < 0.0] = 0.0

        X, Z, active_set_, _, _ = _tf_mixed_norm_solver_bcd_active_set(
            M,
            G[:, active_set],
            alpha_space,
            alpha_time,
            lc[active_set[::n_orient]],
            phi,
            phiT,
            Z_init=Z,
            w_space=w_space,
            w_time=w_time,
            n_orient=n_orient,
            maxit=maxit,
            tol=tol,
            dgap_freq=dgap_freq,
        )

        # map the subproblem active set back into full-source indexing
        active_set[active_set] = active_set_

        if active_set.sum() > 0:
            l21_penalty = np.sum(g_space(Z.copy()))
            l1_penalty = phi.norm(g_time(Z.copy()), ord=1).sum()

            p_obj = (
                0.5 * np.linalg.norm(M - np.dot(G[:, active_set], X), "fro") ** 2.0
                + alpha_space * l21_penalty
                + alpha_time * l1_penalty
            )
            E.append(p_obj)

            logger.info(
                "Iteration %d: active set size=%d, E=%f"
                % (k + 1, active_set.sum() / n_orient, p_obj)
            )

            # Check convergence
            if np.array_equal(active_set, active_set_0):
                max_diff = np.amax(np.abs(Z - Z0))
                if max_diff < tol:
                    # use logger (not print) like the rest of this module
                    logger.info("Convergence reached after %d reweightings!" % k)
                    break
        else:
            # all sources were shrunk to zero: objective is the data norm only
            p_obj = 0.5 * np.linalg.norm(M) ** 2.0
            E.append(p_obj)
            logger.info(
                "Iteration %d: as_size=%d, E=%f"
                % (k + 1, active_set.sum() / n_orient, p_obj)
            )
            break

    if debias:
        if active_set.sum() > 0:
            bias = compute_bias(M, G[:, active_set], X, n_orient=n_orient)
            X *= bias[:, np.newaxis]

    return X, active_set, E