316 lines
9.8 KiB
Python
316 lines
9.8 KiB
Python
"""Compute a Recursively Applied and Projected MUltiple Signal Classification (RAP-MUSIC).""" # noqa
|
|
|
|
# Authors: The MNE-Python contributors.
|
|
# License: BSD-3-Clause
|
|
# Copyright the MNE-Python contributors.
|
|
|
|
import numpy as np
|
|
from scipy import linalg
|
|
|
|
from .._fiff.pick import pick_channels_forward, pick_info
|
|
from ..fixes import _safe_svd
|
|
from ..forward import convert_forward_solution, is_fixed_orient
|
|
from ..inverse_sparse.mxne_inverse import _make_dipoles_sparse
|
|
from ..minimum_norm.inverse import _log_exp_var
|
|
from ..utils import _check_info_inv, fill_doc, logger, verbose
|
|
from ._compute_beamformer import _prepare_beamformer_input
|
|
|
|
|
|
@fill_doc
|
|
def _apply_rap_music(
|
|
data, info, times, forward, noise_cov, n_dipoles=2, picks=None, use_trap=False
|
|
):
|
|
"""RAP-MUSIC or TRAP-MUSIC for evoked data.
|
|
|
|
Parameters
|
|
----------
|
|
data : array, shape (n_channels, n_times)
|
|
Evoked data.
|
|
%(info_not_none)s
|
|
times : array
|
|
Times.
|
|
forward : instance of Forward
|
|
Forward operator.
|
|
noise_cov : instance of Covariance
|
|
The noise covariance.
|
|
n_dipoles : int
|
|
The number of dipoles to estimate. The default value is 2.
|
|
picks : list of int
|
|
Caller ensures this is a list of int.
|
|
use_trap : bool
|
|
Use the TRAP-MUSIC variant if True (default False).
|
|
|
|
Returns
|
|
-------
|
|
dipoles : list of instances of Dipole
|
|
The dipole fits.
|
|
explained_data : array | None
|
|
Data explained by the dipoles using a least square fitting with the
|
|
selected active dipoles and their estimated orientation.
|
|
"""
|
|
info = pick_info(info, picks)
|
|
del picks
|
|
# things are much simpler if we avoid surface orientation
|
|
align = forward["source_nn"].copy()
|
|
if forward["surf_ori"] and not is_fixed_orient(forward):
|
|
forward = convert_forward_solution(forward, surf_ori=False)
|
|
is_free_ori, info, _, _, G, whitener, _, _ = _prepare_beamformer_input(
|
|
info, forward, noise_cov=noise_cov, rank=None
|
|
)
|
|
forward = pick_channels_forward(forward, info["ch_names"], ordered=True)
|
|
del info
|
|
|
|
# whiten the data (leadfield already whitened)
|
|
M = np.dot(whitener, data)
|
|
del data
|
|
|
|
_, eig_vectors = linalg.eigh(np.dot(M, M.T))
|
|
phi_sig = eig_vectors[:, -n_dipoles:]
|
|
|
|
n_orient = 3 if is_free_ori else 1
|
|
G.shape = (G.shape[0], -1, n_orient)
|
|
gain = forward["sol"]["data"].copy()
|
|
gain.shape = G.shape
|
|
n_channels = G.shape[0]
|
|
A = np.empty((n_channels, n_dipoles))
|
|
gain_dip = np.empty((n_channels, n_dipoles))
|
|
oris = np.empty((n_dipoles, 3))
|
|
poss = np.empty((n_dipoles, 3))
|
|
|
|
G_proj = G.copy()
|
|
phi_sig_proj = phi_sig.copy()
|
|
|
|
idxs = list()
|
|
for k in range(n_dipoles):
|
|
subcorr_max = -1.0
|
|
source_idx, source_ori, source_pos = 0, [0, 0, 0], [0, 0, 0]
|
|
for i_source in range(G.shape[1]):
|
|
Gk = G_proj[:, i_source]
|
|
subcorr, ori = _compute_subcorr(Gk, phi_sig_proj)
|
|
if subcorr > subcorr_max:
|
|
subcorr_max = subcorr
|
|
source_idx = i_source
|
|
source_ori = ori
|
|
source_pos = forward["source_rr"][i_source]
|
|
if n_orient == 3 and align is not None:
|
|
surf_normal = forward["source_nn"][3 * i_source + 2]
|
|
# make sure ori is aligned to the surface orientation
|
|
source_ori *= np.sign(source_ori @ surf_normal) or 1.0
|
|
if n_orient == 1:
|
|
source_ori = forward["source_nn"][i_source]
|
|
|
|
idxs.append(source_idx)
|
|
if n_orient == 3:
|
|
Ak = np.dot(G[:, source_idx], source_ori)
|
|
else:
|
|
Ak = G[:, source_idx, 0]
|
|
A[:, k] = Ak
|
|
oris[k] = source_ori
|
|
poss[k] = source_pos
|
|
|
|
logger.info(f"source {k + 1} found: p = {source_idx}")
|
|
if n_orient == 3:
|
|
logger.info("ori = {} {} {}".format(*tuple(oris[k])))
|
|
|
|
projection = _compute_proj(A[:, : k + 1])
|
|
G_proj = np.einsum("ab,bso->aso", projection, G)
|
|
phi_sig_proj = np.dot(projection, phi_sig)
|
|
if use_trap:
|
|
phi_sig_proj = phi_sig_proj[:, -(n_dipoles - k) :]
|
|
del G, G_proj
|
|
|
|
sol = linalg.lstsq(A, M)[0]
|
|
if n_orient == 3:
|
|
X = sol[:, np.newaxis] * oris[:, :, np.newaxis]
|
|
X.shape = (-1, len(times))
|
|
else:
|
|
X = sol
|
|
|
|
gain_active = gain[:, idxs]
|
|
if n_orient == 3:
|
|
gain_dip = (oris * gain_active).sum(-1)
|
|
idxs = np.array(idxs)
|
|
active_set = np.array([[3 * idxs, 3 * idxs + 1, 3 * idxs + 2]]).T.ravel()
|
|
else:
|
|
gain_dip = gain_active[:, :, 0]
|
|
active_set = idxs
|
|
gain_active = whitener @ gain_active.reshape(gain.shape[0], -1)
|
|
assert gain_active.shape == (n_channels, X.shape[0])
|
|
|
|
explained_data = gain_dip @ sol
|
|
M_estimate = whitener @ explained_data
|
|
_log_exp_var(M, M_estimate)
|
|
tstep = np.median(np.diff(times)) if len(times) > 1 else 1.0
|
|
dipoles = _make_dipoles_sparse(
|
|
X, active_set, forward, times[0], tstep, M, gain_active, active_is_idx=True
|
|
)
|
|
for dipole, ori in zip(dipoles, oris):
|
|
signs = np.sign((dipole.ori * ori).sum(-1, keepdims=True))
|
|
dipole.ori *= signs
|
|
dipole.amplitude *= signs[:, 0]
|
|
logger.info("[done]")
|
|
return dipoles, explained_data
|
|
|
|
|
|
def _compute_subcorr(G, phi_sig):
|
|
"""Compute the subspace correlation."""
|
|
Ug, Sg, Vg = _safe_svd(G, full_matrices=False)
|
|
# Now we look at the actual rank of the forward fields
|
|
# in G and handle the fact that it might be rank defficient
|
|
# eg. when using MEG and a sphere model for which the
|
|
# radial component will be truly 0.
|
|
rank = np.sum(Sg > (Sg[0] * 1e-6))
|
|
if rank == 0:
|
|
return 0, np.zeros(len(G))
|
|
rank = max(rank, 2) # rank cannot be 1
|
|
Ug, Sg, Vg = Ug[:, :rank], Sg[:rank], Vg[:rank]
|
|
tmp = np.dot(Ug.T.conjugate(), phi_sig)
|
|
Uc, Sc, _ = _safe_svd(tmp, full_matrices=False)
|
|
X = np.dot(Vg.T / Sg[None, :], Uc[:, 0]) # subcorr
|
|
return Sc[0], X / np.linalg.norm(X)
|
|
|
|
|
|
def _compute_proj(A):
|
|
"""Compute the orthogonal projection operation for a manifold vector A."""
|
|
U, _, _ = _safe_svd(A, full_matrices=False)
|
|
return np.identity(A.shape[0]) - np.dot(U, U.T.conjugate())
|
|
|
|
|
|
def _rap_music(evoked, forward, noise_cov, n_dipoles, return_residual, use_trap):
|
|
"""RAP-/TRAP-MUSIC implementation."""
|
|
info = evoked.info
|
|
data = evoked.data
|
|
times = evoked.times
|
|
|
|
picks = _check_info_inv(info, forward, data_cov=None, noise_cov=noise_cov)
|
|
|
|
data = data[picks]
|
|
|
|
dipoles, explained_data = _apply_rap_music(
|
|
data, info, times, forward, noise_cov, n_dipoles, picks, use_trap
|
|
)
|
|
|
|
if return_residual:
|
|
residual = evoked.copy().pick([info["ch_names"][p] for p in picks])
|
|
residual.data -= explained_data
|
|
active_projs = [p for p in residual.info["projs"] if p["active"]]
|
|
for p in active_projs:
|
|
p["active"] = False
|
|
residual.add_proj(active_projs, remove_existing=True)
|
|
residual.apply_proj()
|
|
return dipoles, residual
|
|
else:
|
|
return dipoles
|
|
|
|
|
|
@verbose
|
|
def rap_music(
|
|
evoked,
|
|
forward,
|
|
noise_cov,
|
|
n_dipoles=5,
|
|
return_residual=False,
|
|
*,
|
|
verbose=None,
|
|
):
|
|
"""RAP-MUSIC source localization method.
|
|
|
|
Compute Recursively Applied and Projected MUltiple SIgnal Classification
|
|
(RAP-MUSIC) :footcite:`MosherLeahy1999,MosherLeahy1996` on evoked data.
|
|
|
|
.. note:: The goodness of fit (GOF) of all the returned dipoles is the
|
|
same and corresponds to the GOF of the full set of dipoles.
|
|
|
|
Parameters
|
|
----------
|
|
evoked : instance of Evoked
|
|
Evoked data to localize.
|
|
forward : instance of Forward
|
|
Forward operator.
|
|
noise_cov : instance of Covariance
|
|
The noise covariance.
|
|
n_dipoles : int
|
|
The number of dipoles to look for. The default value is 5.
|
|
return_residual : bool
|
|
If True, the residual is returned as an Evoked instance.
|
|
%(verbose)s
|
|
|
|
Returns
|
|
-------
|
|
dipoles : list of instance of Dipole
|
|
The dipole fits.
|
|
residual : instance of Evoked
|
|
The residual a.k.a. data not explained by the dipoles.
|
|
Only returned if return_residual is True.
|
|
|
|
See Also
|
|
--------
|
|
mne.fit_dipole
|
|
mne.beamformer.trap_music
|
|
|
|
Notes
|
|
-----
|
|
.. versionadded:: 0.9.0
|
|
|
|
References
|
|
----------
|
|
.. footbibliography::
|
|
"""
|
|
return _rap_music(evoked, forward, noise_cov, n_dipoles, return_residual, False)
|
|
|
|
|
|
@verbose
|
|
def trap_music(
|
|
evoked,
|
|
forward,
|
|
noise_cov,
|
|
n_dipoles=5,
|
|
return_residual=False,
|
|
*,
|
|
verbose=None,
|
|
):
|
|
"""TRAP-MUSIC source localization method.
|
|
|
|
Compute Truncated Recursively Applied and Projected MUltiple SIgnal Classification
|
|
(TRAP-MUSIC) :footcite:`Makela2018` on evoked data.
|
|
|
|
.. note:: The goodness of fit (GOF) of all the returned dipoles is the
|
|
same and corresponds to the GOF of the full set of dipoles.
|
|
|
|
Parameters
|
|
----------
|
|
evoked : instance of Evoked
|
|
Evoked data to localize.
|
|
forward : instance of Forward
|
|
Forward operator.
|
|
noise_cov : instance of Covariance
|
|
The noise covariance.
|
|
n_dipoles : int
|
|
The number of dipoles to look for. The default value is 5.
|
|
return_residual : bool
|
|
If True, the residual is returned as an Evoked instance.
|
|
%(verbose)s
|
|
|
|
Returns
|
|
-------
|
|
dipoles : list of instance of Dipole
|
|
The dipole fits.
|
|
residual : instance of Evoked
|
|
The residual a.k.a. data not explained by the dipoles.
|
|
Only returned if return_residual is True.
|
|
|
|
See Also
|
|
--------
|
|
mne.fit_dipole
|
|
mne.beamformer.rap_music
|
|
|
|
Notes
|
|
-----
|
|
.. versionadded:: 1.4
|
|
|
|
References
|
|
----------
|
|
.. footbibliography::
|
|
"""
|
|
return _rap_music(evoked, forward, noise_cov, n_dipoles, return_residual, True)
|