232 lines
8.4 KiB
Python
232 lines
8.4 KiB
Python
"""Tools for data interpolation."""
|
|
|
|
# Authors: The MNE-Python contributors.
|
|
# License: BSD-3-Clause
|
|
# Copyright the MNE-Python contributors.
|
|
|
|
from itertools import chain
|
|
|
|
import numpy as np
|
|
from scipy.sparse.csgraph import connected_components
|
|
|
|
from .._fiff.meas_info import create_info
|
|
from ..epochs import BaseEpochs, EpochsArray
|
|
from ..evoked import Evoked, EvokedArray
|
|
from ..io import BaseRaw, RawArray
|
|
from ..transforms import _cart_to_sph, _sph_to_cart
|
|
from ..utils import _ensure_int, _validate_type
|
|
|
|
|
|
def equalize_bads(insts, interp_thresh=1.0, copy=True):
    """Interpolate or mark bads consistently for a list of instances.

    Once called on a list of instances, the instances can be concatenated
    as they will have the same list of bad channels.

    Parameters
    ----------
    insts : list
        The list of instances (Evoked, Epochs or Raw) to consider
        for interpolation. Each instance should have marked channels.
    interp_thresh : float
        A float between 0 and 1 (default) that specifies the fraction of time
        a channel should be good to be eventually interpolated for certain
        instances. For example if 0.5, a channel which is good at least half
        of the time will be interpolated in the instances where it is marked
        as bad. If 1 then channels will never be interpolated and if 0 all bad
        channels will be systematically interpolated.
    copy : bool
        If True then the returned instances will be copies.

    Returns
    -------
    insts_bads : list
        The list of instances, with the same channel(s) marked as bad in all of
        them, possibly with some formerly bad channels interpolated. The
        remaining bad channels are listed in the order in which they are first
        encountered across ``insts``. An empty ``insts`` yields an empty list.

    Raises
    ------
    ValueError
        If ``interp_thresh`` is not between 0 and 1.
    """
    if not 0 <= interp_thresh <= 1:
        raise ValueError(f"interp_thresh must be between 0 and 1, got {interp_thresh}")

    # Nothing to equalize: avoid indexing into an empty list below.
    if len(insts) == 0:
        return list(insts) if copy else insts

    # Union of all bad channels, de-duplicated in first-seen order so that the
    # resulting ``info["bads"]`` does not depend on string hash randomization.
    all_bads = list(
        dict.fromkeys(chain.from_iterable(inst.info["bads"] for inst in insts))
    )

    # The instance type is decided from the first element only; for epochs the
    # duration counts every sample of every epoch.
    if isinstance(insts[0], BaseEpochs):
        durations = [len(inst) * len(inst.times) for inst in insts]
    else:
        durations = [len(inst.times) for inst in insts]
    total_duration = np.sum(durations)

    # Build each instance's bad set once instead of scanning its list for
    # every candidate channel.
    bads_per_inst = [set(inst.info["bads"]) for inst in insts]

    # A channel stays bad everywhere when the fraction of the total duration
    # during which it is good is below the threshold.
    bads_keep = []
    for ch_name in all_bads:
        good_duration = sum(
            duration
            for duration, inst_bads in zip(durations, bads_per_inst)
            if ch_name not in inst_bads
        )
        if good_duration / total_duration < interp_thresh:
            bads_keep.append(ch_name)

    if copy:
        insts = [inst.copy() for inst in insts]

    keep = set(bads_keep)
    for inst in insts:
        # Only interpolate when this instance has bads that are not kept.
        if set(inst.info["bads"]) - keep:
            inst.interpolate_bads(exclude=bads_keep)
        # Hand every instance its own list so they never share one object.
        inst.info["bads"] = list(bads_keep)

    return insts
|
|
|
|
|
|
def interpolate_bridged_electrodes(inst, bridged_idx, bad_limit=4):
    """Interpolate bridged electrode pairs.

    Because bridged electrodes contain brain signal, it's just that the
    signal is spatially smeared between the two electrodes, we can
    make a virtual channel midway between the bridged pairs and use
    that to aid in interpolation rather than completely discarding the
    data from the two channels.

    Parameters
    ----------
    inst : instance of Epochs, Evoked, or Raw
        The data object with channels that are to be interpolated.
    bridged_idx : list of tuple
        The indices of channels marked as bridged with each bridged
        pair stored as a tuple.
    bad_limit : int
        The maximum number of electrodes that can be bridged together
        (included) and interpolated. Above this number, an error will be
        raised.

        .. versionadded:: 1.2

    Returns
    -------
    inst : instance of Epochs, Evoked, or Raw
        The modified data object.

    Raises
    ------
    ValueError
        If ``bad_limit`` is not a strictly positive integer.
    RuntimeError
        If ``inst`` has no montage, if the montage positions are not in the
        ``head`` coordinate frame, or if a group of bridged electrodes
        contains more than ``bad_limit`` channels.

    See Also
    --------
    mne.preprocessing.compute_bridged_electrodes

    Notes
    -----
    The function works on ``inst`` directly (virtual channels are temporarily
    added to it and then dropped) and returns that same object. The bad
    channels that were marked on entry are put back before returning, also
    when an error is raised part-way through.
    """
    _validate_type(inst, (BaseRaw, BaseEpochs, Evoked))
    bad_limit = _ensure_int(bad_limit, "bad_limit")
    if bad_limit <= 0:
        raise ValueError(
            "Argument 'bad_limit' should be a strictly positive "
            f"integer. Provided {bad_limit} is invalid."
        )
    montage = inst.get_montage()
    if montage is None:
        raise RuntimeError("No channel positions found in ``inst``")
    pos = montage.get_positions()
    if pos["coord_frame"] != "head":
        raise RuntimeError(
            f"Montage channel positions must be in ``head`` got {pos['coord_frame']}"
        )

    # look for groups of bridged channels: build an undirected graph whose
    # nodes are the bridged channel indices and whose edges are the pairs
    nodes = sorted(set(chain(*bridged_idx)))
    G_dense = np.zeros((len(nodes), len(nodes)))
    # fill the edges with a weight of 1
    for bridge in bridged_idx:
        idx0 = np.searchsorted(nodes, bridge[0])
        idx1 = np.searchsorted(nodes, bridge[1])
        G_dense[idx0, idx1] = 1
        G_dense[idx1, idx0] = 1
    # look for connected components
    _, labels = connected_components(G_dense, directed=False)
    groups_idx = [[nodes[j] for j in np.where(labels == k)[0]] for k in set(labels)]
    groups_names = [
        [inst.info.ch_names[ch_idx] for ch_idx in group_idx]
        for group_idx in groups_idx
    ]

    # raise for any bridged area that includes too many electrodes; this is
    # checked before ``inst`` is touched so a rejection leaves it unchanged
    for group_names in groups_names:
        if len(group_names) > bad_limit:
            raise RuntimeError(
                f"The channels {', '.join(group_names)} are bridged together "
                "and form a large area of bridged electrodes. Interpolation "
                "might be inaccurate."
            )

    # store bads orig to put back at the end
    bads_orig = list(inst.info["bads"])
    inst.info["bads"] = list()
    try:
        # make virtual channels, one per group of bridged electrodes
        virtual_chs = dict()
        for number, group_names in enumerate(groups_names, start=1):
            virtual_name = f"virtual {number}"
            # compute centroid position in spherical "head" coordinates
            pos_virtual = _find_centroid_sphere(pos["ch_pos"], group_names)
            # create the virtual channel info and set the position
            virtual_info = create_info([virtual_name], inst.info["sfreq"], "eeg")
            virtual_info["chs"][0]["loc"][:3] = pos_virtual
            # create virtual channel as the average of the group's channels
            data = inst.get_data(picks=group_names)
            if isinstance(inst, BaseRaw):
                data = np.average(data, axis=0).reshape(1, -1)
                virtual_ch = RawArray(data, virtual_info, first_samp=inst.first_samp)
            elif isinstance(inst, BaseEpochs):
                # the channel axis is axis 1 here; keep one row per epoch
                data = np.average(data, axis=1).reshape(len(data), 1, -1)
                virtual_ch = EpochsArray(data, virtual_info, tmin=inst.tmin)
            else:  # evoked
                data = np.average(data, axis=0).reshape(1, -1)
                virtual_ch = EvokedArray(
                    data,
                    virtual_info,
                    tmin=inst.tmin,
                    nave=inst.nave,
                    kind=inst.kind,
                )
            virtual_chs[virtual_name] = virtual_ch

        # add the virtual channels
        inst.add_channels(list(virtual_chs.values()), force_update_info=True)
        try:
            # use the virtual channels to interpolate; the groups are disjoint
            # connected components, so flattening them gives no duplicates
            inst.info["bads"] = list(chain.from_iterable(groups_names))
            inst.interpolate_bads()
        finally:
            # drop virtual channels, also when the interpolation failed
            inst.drop_channels(list(virtual_chs.keys()))
    finally:
        inst.info["bads"] = bads_orig
    return inst
|
|
|
|
|
|
def _find_centroid_sphere(ch_pos, group_names):
    """Compute the centroid position between N electrodes.

    The centroid is determined in spherical "head" coordinates, which is
    more accurate than a plain cartesian average because the latter cuts
    through the scalp.

    The approach is simple: average the locations in cartesian coordinates,
    convert that mean to spherical coordinates, and replace its radius with
    the average radius of the N points in spherical coordinates.

    Parameters
    ----------
    ch_pos : OrderedDict
        The position of all channels in cartesian coordinates.
    group_names : list | tuple
        The name of the N electrodes used to determine the centroid.

    Returns
    -------
    pos_centroid : array of shape (3,)
        The position of the centroid in cartesian coordinates.
    """
    xyz = np.array([ch_pos[name] for name in group_names])
    sph = _cart_to_sph(xyz)
    # the cartesian mean gives the direction of the centroid
    centroid_sph = _cart_to_sph(np.average(xyz, axis=0))
    # its radius is replaced by the mean radius of the electrodes
    # (first spherical component)
    centroid_sph[0, 0] = np.average(sph, axis=0)[0]
    # back to cartesian, as a flat (3,) position
    return _sph_to_cart(centroid_sph)[0, :]
|