418 lines
14 KiB
Python
418 lines
14 KiB
Python
# Authors: The MNE-Python contributors.
|
|
# License: BSD-3-Clause
|
|
# Copyright the MNE-Python contributors.
|
|
|
|
import numpy as np
|
|
from numpy.polynomial.legendre import legval
|
|
from scipy.interpolate import RectBivariateSpline
|
|
from scipy.linalg import pinv
|
|
from scipy.spatial.distance import pdist, squareform
|
|
|
|
from .._fiff.meas_info import _simplify_info
|
|
from .._fiff.pick import pick_channels, pick_info, pick_types
|
|
from ..surface import _normalize_vectors
|
|
from ..utils import _validate_type, logger, verbose, warn
|
|
|
|
|
|
def _calc_h(cosang, stiffness=4, n_legendre_terms=50):
|
|
"""Calculate spherical spline h function between points on a sphere.
|
|
|
|
Parameters
|
|
----------
|
|
cosang : array-like | float
|
|
cosine of angles between pairs of points on a spherical surface. This
|
|
is equivalent to the dot product of unit vectors.
|
|
stiffness : float
|
|
stiffnes of the spline. Also referred to as ``m``.
|
|
n_legendre_terms : int
|
|
number of Legendre terms to evaluate.
|
|
"""
|
|
factors = [
|
|
(2 * n + 1) / (n ** (stiffness - 1) * (n + 1) ** (stiffness - 1) * 4 * np.pi)
|
|
for n in range(1, n_legendre_terms + 1)
|
|
]
|
|
return legval(cosang, [0] + factors)
|
|
|
|
|
|
def _calc_g(cosang, stiffness=4, n_legendre_terms=50):
|
|
"""Calculate spherical spline g function between points on a sphere.
|
|
|
|
Parameters
|
|
----------
|
|
cosang : array-like of float, shape(n_channels, n_channels)
|
|
cosine of angles between pairs of points on a spherical surface. This
|
|
is equivalent to the dot product of unit vectors.
|
|
stiffness : float
|
|
stiffness of the spline.
|
|
n_legendre_terms : int
|
|
number of Legendre terms to evaluate.
|
|
|
|
Returns
|
|
-------
|
|
G : np.ndrarray of float, shape(n_channels, n_channels)
|
|
The G matrix.
|
|
"""
|
|
factors = [
|
|
(2 * n + 1) / (n**stiffness * (n + 1) ** stiffness * 4 * np.pi)
|
|
for n in range(1, n_legendre_terms + 1)
|
|
]
|
|
return legval(cosang, [0] + factors)
|
|
|
|
|
|
def _make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
|
|
"""Compute interpolation matrix based on spherical splines.
|
|
|
|
Implementation based on [1]
|
|
|
|
Parameters
|
|
----------
|
|
pos_from : np.ndarray of float, shape(n_good_sensors, 3)
|
|
The positions to interpolate from.
|
|
pos_to : np.ndarray of float, shape(n_bad_sensors, 3)
|
|
The positions to interpolate.
|
|
alpha : float
|
|
Regularization parameter. Defaults to 1e-5.
|
|
|
|
Returns
|
|
-------
|
|
interpolation : np.ndarray of float, shape(len(pos_from), len(pos_to))
|
|
The interpolation matrix that maps good signals to the location
|
|
of bad signals.
|
|
|
|
References
|
|
----------
|
|
[1] Perrin, F., Pernier, J., Bertrand, O. and Echallier, JF. (1989).
|
|
Spherical splines for scalp potential and current density mapping.
|
|
Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7.
|
|
"""
|
|
pos_from = pos_from.copy()
|
|
pos_to = pos_to.copy()
|
|
n_from = pos_from.shape[0]
|
|
n_to = pos_to.shape[0]
|
|
|
|
# normalize sensor positions to sphere
|
|
_normalize_vectors(pos_from)
|
|
_normalize_vectors(pos_to)
|
|
|
|
# cosine angles between source positions
|
|
cosang_from = pos_from.dot(pos_from.T)
|
|
cosang_to_from = pos_to.dot(pos_from.T)
|
|
G_from = _calc_g(cosang_from)
|
|
G_to_from = _calc_g(cosang_to_from)
|
|
assert G_from.shape == (n_from, n_from)
|
|
assert G_to_from.shape == (n_to, n_from)
|
|
|
|
if alpha is not None:
|
|
G_from.flat[:: len(G_from) + 1] += alpha
|
|
|
|
C = np.vstack(
|
|
[
|
|
np.hstack([G_from, np.ones((n_from, 1))]),
|
|
np.hstack([np.ones((1, n_from)), [[0]]]),
|
|
]
|
|
)
|
|
C_inv = pinv(C)
|
|
|
|
interpolation = np.hstack([G_to_from, np.ones((n_to, 1))]) @ C_inv[:, :-1]
|
|
assert interpolation.shape == (n_to, n_from)
|
|
return interpolation
|
|
|
|
|
|
def _do_interp_dots(inst, interpolation, goods_idx, bads_idx):
|
|
"""Dot product of channel mapping matrix to channel data."""
|
|
from ..epochs import BaseEpochs
|
|
from ..evoked import Evoked
|
|
from ..io import BaseRaw
|
|
|
|
_validate_type(inst, (BaseRaw, BaseEpochs, Evoked), "inst")
|
|
inst._data[..., bads_idx, :] = np.matmul(
|
|
interpolation, inst._data[..., goods_idx, :]
|
|
)
|
|
|
|
|
|
@verbose
|
|
def _interpolate_bads_eeg(inst, origin, exclude=None, ecog=False, verbose=None):
|
|
if exclude is None:
|
|
exclude = list()
|
|
bads_idx = np.zeros(len(inst.ch_names), dtype=bool)
|
|
goods_idx = np.zeros(len(inst.ch_names), dtype=bool)
|
|
|
|
picks = pick_types(inst.info, meg=False, eeg=not ecog, ecog=ecog, exclude=exclude)
|
|
inst.info._check_consistency()
|
|
bads_idx[picks] = [inst.ch_names[ch] in inst.info["bads"] for ch in picks]
|
|
|
|
if len(picks) == 0 or bads_idx.sum() == 0:
|
|
return
|
|
|
|
goods_idx[picks] = True
|
|
goods_idx[bads_idx] = False
|
|
|
|
pos = inst._get_channel_positions(picks)
|
|
|
|
# Make sure only EEG are used
|
|
bads_idx_pos = bads_idx[picks]
|
|
goods_idx_pos = goods_idx[picks]
|
|
|
|
# test spherical fit
|
|
distance = np.linalg.norm(pos - origin, axis=-1)
|
|
distance = np.mean(distance / np.mean(distance))
|
|
if np.abs(1.0 - distance) > 0.1:
|
|
warn(
|
|
"Your spherical fit is poor, interpolation results are "
|
|
"likely to be inaccurate."
|
|
)
|
|
|
|
pos_good = pos[goods_idx_pos] - origin
|
|
pos_bad = pos[bads_idx_pos] - origin
|
|
logger.info(f"Computing interpolation matrix from {len(pos_good)} sensor positions")
|
|
interpolation = _make_interpolation_matrix(pos_good, pos_bad)
|
|
|
|
logger.info(f"Interpolating {len(pos_bad)} sensors")
|
|
_do_interp_dots(inst, interpolation, goods_idx, bads_idx)
|
|
|
|
|
|
@verbose
|
|
def _interpolate_bads_ecog(inst, origin, exclude=None, verbose=None):
|
|
_interpolate_bads_eeg(inst, origin, exclude=exclude, ecog=True, verbose=verbose)
|
|
|
|
|
|
def _interpolate_bads_meg(
|
|
inst, mode="accurate", origin=(0.0, 0.0, 0.04), verbose=None, ref_meg=False
|
|
):
|
|
return _interpolate_bads_meeg(
|
|
inst, mode, origin, ref_meg=ref_meg, eeg=False, verbose=verbose
|
|
)
|
|
|
|
|
|
@verbose
|
|
def _interpolate_bads_nan(
|
|
inst,
|
|
ch_type,
|
|
ref_meg=False,
|
|
exclude=(),
|
|
*,
|
|
verbose=None,
|
|
):
|
|
info = _simplify_info(inst.info)
|
|
picks_type = pick_types(info, ref_meg=ref_meg, exclude=exclude, **{ch_type: True})
|
|
use_ch_names = [inst.info["ch_names"][p] for p in picks_type]
|
|
bads_type = [ch for ch in inst.info["bads"] if ch in use_ch_names]
|
|
if len(bads_type) == 0 or len(picks_type) == 0:
|
|
return
|
|
# select the bad channels to be interpolated
|
|
picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[])
|
|
inst._data[..., picks_bad, :] = np.nan
|
|
|
|
|
|
@verbose
|
|
def _interpolate_bads_meeg(
|
|
inst,
|
|
mode="accurate",
|
|
origin=(0.0, 0.0, 0.04),
|
|
meg=True,
|
|
eeg=True,
|
|
ref_meg=False,
|
|
exclude=(),
|
|
*,
|
|
method=None,
|
|
verbose=None,
|
|
):
|
|
from ..forward import _map_meg_or_eeg_channels
|
|
|
|
if method is None:
|
|
method = {"meg": "MNE", "eeg": "MNE"}
|
|
bools = dict(meg=meg, eeg=eeg)
|
|
info = _simplify_info(inst.info)
|
|
for ch_type, do in bools.items():
|
|
if not do:
|
|
continue
|
|
kw = dict(meg=False, eeg=False)
|
|
kw[ch_type] = True
|
|
picks_type = pick_types(info, ref_meg=ref_meg, exclude=exclude, **kw)
|
|
picks_good = pick_types(info, ref_meg=ref_meg, exclude="bads", **kw)
|
|
use_ch_names = [inst.info["ch_names"][p] for p in picks_type]
|
|
bads_type = [ch for ch in inst.info["bads"] if ch in use_ch_names]
|
|
if len(bads_type) == 0 or len(picks_type) == 0:
|
|
continue
|
|
# select the bad channels to be interpolated
|
|
picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[])
|
|
|
|
# do MNE based interpolation
|
|
if ch_type == "eeg":
|
|
picks_to = picks_type
|
|
bad_sel = np.isin(picks_type, picks_bad)
|
|
else:
|
|
picks_to = picks_bad
|
|
bad_sel = slice(None)
|
|
info_from = pick_info(inst.info, picks_good)
|
|
info_to = pick_info(inst.info, picks_to)
|
|
mapping = _map_meg_or_eeg_channels(info_from, info_to, mode=mode, origin=origin)
|
|
mapping = mapping[bad_sel]
|
|
_do_interp_dots(inst, mapping, picks_good, picks_bad)
|
|
|
|
|
|
@verbose
|
|
def _interpolate_bads_nirs(inst, exclude=(), verbose=None):
|
|
from mne.preprocessing.nirs import _validate_nirs_info
|
|
|
|
if len(pick_types(inst.info, fnirs=True, exclude=())) == 0:
|
|
return
|
|
|
|
# Returns pick of all nirs and ensures channels are correctly ordered
|
|
picks_nirs = _validate_nirs_info(inst.info)
|
|
nirs_ch_names = [inst.info["ch_names"][p] for p in picks_nirs]
|
|
nirs_ch_names = [ch for ch in nirs_ch_names if ch not in exclude]
|
|
bads_nirs = [ch for ch in inst.info["bads"] if ch in nirs_ch_names]
|
|
if len(bads_nirs) == 0:
|
|
return
|
|
picks_bad = pick_channels(inst.info["ch_names"], bads_nirs, exclude=[])
|
|
bads_mask = [p in picks_bad for p in picks_nirs]
|
|
|
|
chs = [inst.info["chs"][i] for i in picks_nirs]
|
|
locs3d = np.array([ch["loc"][:3] for ch in chs])
|
|
|
|
dist = pdist(locs3d)
|
|
dist = squareform(dist)
|
|
|
|
for bad in picks_bad:
|
|
dists_to_bad = dist[bad]
|
|
# Ignore distances to self
|
|
dists_to_bad[dists_to_bad == 0] = np.inf
|
|
# Ignore distances to other bad channels
|
|
dists_to_bad[bads_mask] = np.inf
|
|
# Find closest remaining channels for same frequency
|
|
closest_idx = np.argmin(dists_to_bad) + (bad % 2)
|
|
inst._data[bad] = inst._data[closest_idx]
|
|
|
|
# TODO: this seems like a bug because it does not respect reset_bads
|
|
inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude]
|
|
|
|
return inst
|
|
|
|
|
|
def _find_seeg_electrode_shaft(pos, tol_shaft=0.002, tol_spacing=1):
|
|
# 1) find nearest neighbor to define the electrode shaft line
|
|
# 2) find all contacts on the same line
|
|
# 3) remove contacts with large distances
|
|
|
|
dist = squareform(pdist(pos))
|
|
np.fill_diagonal(dist, np.inf)
|
|
|
|
shafts = list()
|
|
shaft_ts = list()
|
|
for i, n1 in enumerate(pos):
|
|
if any([i in shaft for shaft in shafts]):
|
|
continue
|
|
n2 = pos[np.argmin(dist[i])] # 1
|
|
# https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html
|
|
shaft_dists = np.linalg.norm(
|
|
np.cross((pos - n1), (pos - n2)), axis=1
|
|
) / np.linalg.norm(n2 - n1)
|
|
shaft = np.where(shaft_dists < tol_shaft)[0] # 2
|
|
shaft_prev = None
|
|
for _ in range(10): # avoid potential cycles
|
|
if np.array_equal(shaft, shaft_prev):
|
|
break
|
|
shaft_prev = shaft
|
|
# compute median shaft line
|
|
v = np.median(
|
|
[
|
|
pos[i] - pos[j]
|
|
for idx, i in enumerate(shaft)
|
|
for j in shaft[idx + 1 :]
|
|
],
|
|
axis=0,
|
|
)
|
|
c = np.median(pos[shaft], axis=0)
|
|
# recompute distances
|
|
shaft_dists = np.linalg.norm(
|
|
np.cross((pos - c), (pos - c + v)), axis=1
|
|
) / np.linalg.norm(v)
|
|
shaft = np.where(shaft_dists < tol_shaft)[0]
|
|
ts = np.array([np.dot(c - n0, v) / np.linalg.norm(v) ** 2 for n0 in pos[shaft]])
|
|
shaft_order = np.argsort(ts)
|
|
shaft = shaft[shaft_order]
|
|
ts = ts[shaft_order]
|
|
|
|
# only include the largest group with spacing with the error tolerance
|
|
# avoid interpolating across spans between contacts
|
|
t_diffs = np.diff(ts)
|
|
t_diff_med = np.median(t_diffs)
|
|
spacing_errors = (t_diffs - t_diff_med) / t_diff_med
|
|
groups = list()
|
|
group = [shaft[0]]
|
|
for j in range(len(shaft) - 1):
|
|
if spacing_errors[j] > tol_spacing:
|
|
groups.append(group)
|
|
group = [shaft[j + 1]]
|
|
else:
|
|
group.append(shaft[j + 1])
|
|
groups.append(group)
|
|
group = [group for group in groups if i in group][0]
|
|
ts = ts[np.isin(shaft, group)]
|
|
shaft = np.array(group, dtype=int)
|
|
|
|
shafts.append(shaft)
|
|
shaft_ts.append(ts)
|
|
return shafts, shaft_ts
|
|
|
|
|
|
@verbose
|
|
def _interpolate_bads_seeg(
|
|
inst, exclude=None, tol_shaft=0.002, tol_spacing=1, verbose=None
|
|
):
|
|
if exclude is None:
|
|
exclude = list()
|
|
picks = pick_types(inst.info, meg=False, seeg=True, exclude=exclude)
|
|
inst.info._check_consistency()
|
|
bads_idx = np.isin(np.array(inst.ch_names)[picks], inst.info["bads"])
|
|
|
|
if len(picks) == 0 or bads_idx.sum() == 0:
|
|
return
|
|
|
|
pos = inst._get_channel_positions(picks)
|
|
|
|
# Make sure only sEEG are used
|
|
bads_idx_pos = bads_idx[picks]
|
|
|
|
shafts, shaft_ts = _find_seeg_electrode_shaft(
|
|
pos, tol_shaft=tol_shaft, tol_spacing=tol_spacing
|
|
)
|
|
|
|
# interpolate the bad contacts
|
|
picks_bad = list(np.where(bads_idx_pos)[0])
|
|
for shaft, ts in zip(shafts, shaft_ts):
|
|
bads_shaft = np.array([idx for idx in picks_bad if idx in shaft])
|
|
if bads_shaft.size == 0:
|
|
continue
|
|
goods_shaft = shaft[np.isin(shaft, bads_shaft, invert=True)]
|
|
if goods_shaft.size < 4: # cubic spline requires 3 channels
|
|
msg = "No shaft" if shaft.size < 4 else "Not enough good channels"
|
|
no_shaft_chs = " and ".join(np.array(inst.ch_names)[bads_shaft])
|
|
raise RuntimeError(
|
|
f"{msg} found in a line with {no_shaft_chs} "
|
|
"at least 3 good channels on the same line "
|
|
f"are required for interpolation, {goods_shaft.size} found. "
|
|
f"Dropping {no_shaft_chs} is recommended."
|
|
)
|
|
logger.debug(
|
|
f"Interpolating {np.array(inst.ch_names)[bads_shaft]} using "
|
|
f"data from {np.array(inst.ch_names)[goods_shaft]}"
|
|
)
|
|
bads_shaft_idx = np.where(np.isin(shaft, bads_shaft))[0]
|
|
goods_shaft_idx = np.where(~np.isin(shaft, bads_shaft))[0]
|
|
|
|
z = inst._data[..., goods_shaft, :]
|
|
is_epochs = z.ndim == 3
|
|
if is_epochs:
|
|
z = z.swapaxes(0, 1)
|
|
z = z.reshape(z.shape[0], -1)
|
|
y = np.arange(z.shape[-1])
|
|
out = RectBivariateSpline(x=ts[goods_shaft_idx], y=y, z=z)(
|
|
x=ts[bads_shaft_idx], y=y
|
|
)
|
|
if is_epochs:
|
|
out = out.reshape(bads_shaft.size, inst._data.shape[0], -1)
|
|
out = out.swapaxes(0, 1)
|
|
inst._data[..., bads_shaft, :] = out
|