360 lines
12 KiB
Python
360 lines
12 KiB
Python
"""Testing functions."""
|
|
|
|
# Authors: The MNE-Python contributors.
|
|
# License: BSD-3-Clause
|
|
# Copyright the MNE-Python contributors.
|
|
|
|
import inspect
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import traceback
|
|
from functools import wraps
|
|
from shutil import rmtree
|
|
from unittest import SkipTest
|
|
|
|
import numpy as np
|
|
from numpy.testing import assert_allclose, assert_array_equal
|
|
from scipy import linalg
|
|
|
|
from ._logging import ClosingStringIO, warn
|
|
from .check import check_version
|
|
from .misc import run_subprocess
|
|
from .numerics import object_diff
|
|
|
|
|
|
def _explain_exception(start=-1, stop=None, prefix="> "):
|
|
"""Explain an exception."""
|
|
# start=-1 means "only the most recent caller"
|
|
etype, value, tb = sys.exc_info()
|
|
string = traceback.format_list(traceback.extract_tb(tb)[start:stop])
|
|
string = "".join(string).split("\n") + traceback.format_exception_only(etype, value)
|
|
string = ":\n" + prefix + ("\n" + prefix).join(string)
|
|
return string
|
|
|
|
|
|
class _TempDir(str):
|
|
"""Create and auto-destroy temp dir.
|
|
|
|
This is designed to be used with testing modules. Instances should be
|
|
defined inside test functions. Instances defined at module level can not
|
|
guarantee proper destruction of the temporary directory.
|
|
|
|
When used at module level, the current use of the __del__() method for
|
|
cleanup can fail because the rmtree function may be cleaned up before this
|
|
object (an alternative could be using the atexit module instead).
|
|
"""
|
|
|
|
def __new__(self): # noqa: D105
|
|
new = str.__new__(self, tempfile.mkdtemp(prefix="tmp_mne_tempdir_"))
|
|
return new
|
|
|
|
def __init__(self):
|
|
self._path = self.__str__()
|
|
|
|
def __del__(self): # noqa: D105
|
|
rmtree(self._path, ignore_errors=True)
|
|
|
|
|
|
def requires_mne(func):
    """Decorate a test function so it is skipped when MNE-C is unavailable.

    Thin wrapper that applies the mark built by ``requires_mne_mark()``.
    """
    skip_without_mne_c = requires_mne_mark()
    return skip_without_mne_c(func)
|
|
|
|
|
|
def requires_mne_mark():
    """Build a pytest ``skipif`` mark for tests that require MNE-C.

    The skip condition is ``not has_mne_c()``, evaluated once when this
    function is called.
    """
    import pytest

    mne_c_missing = not has_mne_c()
    return pytest.mark.skipif(mne_c_missing, reason="Requires MNE-C")
|
|
|
|
|
|
def requires_openmeeg_mark():
    """Build a pytest ``skipif`` mark for tests that require OpenMEEG.

    Tests are skipped unless ``check_version("openmeeg", "2.5.6")`` is true;
    the check runs once, when this function is called.
    """
    import pytest

    openmeeg_ok = check_version("openmeeg", "2.5.6")
    return pytest.mark.skipif(
        not openmeeg_ok, reason="Requires OpenMEEG >= 2.5.6"
    )
|
|
|
|
|
|
def requires_freesurfer(arg):
    """Require Freesurfer.

    Supports two usages:

    * ``@requires_freesurfer`` applied directly to a test function: the test
      is skipped unless ``has_freesurfer()`` is true.
    * ``@requires_freesurfer("progname")``: ``progname --version`` is run
      immediately (at decoration time) and, if that raises for any reason,
      the returned mark skips the test.
    """
    import pytest

    reason = "Requires Freesurfer"
    if not isinstance(arg, str):
        # Bare-decorator form: ``arg`` is the test itself; only the
        # environment check is performed.
        return pytest.mark.skipif(not has_freesurfer(), reason=reason)(arg)

    # Factory form: probe the named program now.
    reason = f"{reason} command: {arg}"
    try:
        run_subprocess([arg, "--version"])
        skip = False
    except Exception:  # any failure to run the program means "skip"
        skip = True
    return pytest.mark.skipif(skip, reason=reason)
|
|
|
|
|
|
def requires_good_network(func):
    """Skip ``func`` when the ``MNE_SKIP_NETWORK_TESTS`` env var is nonzero.

    The variable is read and converted with ``int`` at decoration time, so
    when set it must hold an integer string such as ``"0"`` or ``"1"``.
    """
    import pytest

    skip_network = int(os.environ.get("MNE_SKIP_NETWORK_TESTS", 0))
    network_mark = pytest.mark.skipif(
        skip_network,
        reason="MNE_SKIP_NETWORK_TESTS is set",
    )
    return network_mark(func)
|
|
|
|
|
|
def run_command_if_main():
    """Call the caller's ``run()`` if the caller's ``__name__`` is ``"__main__"``.

    Both names are looked up in the calling frame's locals.
    """
    caller_frame = inspect.currentframe().f_back
    try:
        caller_locals = caller_frame.f_locals
        if caller_locals.get("__name__", "") != "__main__":
            return
        caller_locals["run"]()
    finally:
        # Release the frame promptly to avoid keeping a reference cycle alive.
        del caller_frame
|
|
|
|
|
|
class ArgvSetter:
    """Temporarily set sys.argv (optionally silencing stdout/stderr).

    Parameters
    ----------
    args : iterable of str
        Arguments placed after a leading ``"python"`` entry in ``sys.argv``.
        Any non-string iterable (tuple, list, ...) is accepted; a bare ``str``
        is rejected with TypeError to avoid splitting it into characters.
    disable_stdout, disable_stderr : bool
        If True, the stream is replaced by a ``ClosingStringIO`` inside the
        ``with`` block; otherwise the stream current at construction is used.
    """

    def __init__(self, args=(), disable_stdout=True, disable_stderr=True):
        if isinstance(args, str):
            raise TypeError(
                f"args must be a sequence of strings, got str {args!r}"
            )
        # Unpacking (rather than tuple concatenation) accepts lists too.
        self.argv = ["python", *args]
        self.stdout = ClosingStringIO() if disable_stdout else sys.stdout
        self.stderr = ClosingStringIO() if disable_stderr else sys.stderr

    def __enter__(self):  # noqa: D105
        self.orig_argv = sys.argv
        sys.argv = self.argv
        self.orig_stdout = sys.stdout
        sys.stdout = self.stdout
        self.orig_stderr = sys.stderr
        sys.stderr = self.stderr
        return self

    def __exit__(self, *args):  # noqa: D105
        sys.argv = self.orig_argv
        sys.stdout = self.orig_stdout
        sys.stderr = self.orig_stderr
|
|
|
|
|
|
def has_mne_c():
    """Return True if the ``MNE_ROOT`` environment variable is defined."""
    return os.getenv("MNE_ROOT") is not None
|
|
|
|
|
|
def has_freesurfer():
    """Return True if the ``FREESURFER_HOME`` environment variable is defined."""
    return os.getenv("FREESURFER_HOME") is not None
|
|
|
|
|
|
def buggy_mkl_svd(function):
    """Turn intermittent "SVD did not converge" failures into test skips.

    A ``np.linalg.LinAlgError`` whose message contains
    ``"SVD did not converge"`` triggers a warning and a ``SkipTest``; any
    other exception propagates unchanged.
    """

    @wraps(function)
    def dec(*args, **kwargs):
        try:
            return function(*args, **kwargs)
        except np.linalg.LinAlgError as err:
            if "SVD did not converge" not in str(err):
                raise
            msg = "Intel MKL SVD convergence error detected, skipping test"
            warn(msg)
            raise SkipTest(msg)

    return dec
|
|
|
|
|
|
def assert_and_remove_boundary_annot(annotations, n=1):
    """Assert there are ``n`` EDGE and ``n`` BAD boundary annotations, then delete them.

    ``annotations`` may be an Annotations-like object or a Raw instance; for
    Raw, its ``.annotations`` attribute is used. Matching entries are removed
    via the object's ``delete`` method.
    """
    from ..io import BaseRaw

    annot = annotations.annotations if isinstance(annotations, BaseRaw) else annotations
    for kind in ("EDGE", "BAD"):
        mask = annot.description == f"{kind} boundary"
        hits = np.where(mask)[0]
        assert len(hits) == n
        annot.delete(hits)
|
|
|
|
|
|
def assert_object_equal(a, b):
    """Assert that ``object_diff`` reports no difference between ``a`` and ``b``.

    On failure, the diff text itself is used as the assertion message.
    """
    diff_text = object_diff(a, b)
    assert diff_text == "", diff_text
|
|
|
|
|
|
def _raw_annot(meas_date, orig_time):
    """Build a small RawArray carrying a single ``"dummy"`` annotation.

    The Raw has 10 channels x 10 samples at 10 Hz with ``first_samp=10``; its
    data come from ``np.empty`` and are not meaningful.

    Parameters
    ----------
    meas_date : object | None
        If not None, it is normalized with ``_handle_meas_date`` and stored as
        ``raw.info["meas_date"]``.
    orig_time : object
        Passed as ``orig_time`` when constructing
        ``Annotations([0.5], [0.2], ["dummy"], orig_time)``.
    """
    from .._fiff.meas_info import create_info
    from ..annotations import Annotations, _handle_meas_date
    from ..io import RawArray

    raw = RawArray(
        data=np.empty((10, 10)),
        info=create_info(ch_names=10, sfreq=10.0),
        first_samp=10,
    )
    if meas_date is not None:
        normalized = _handle_meas_date(meas_date)
        with raw.info._unlock(check_after=True):
            raw.info["meas_date"] = normalized
    dummy = Annotations([0.5], [0.2], ["dummy"], orig_time)
    raw.set_annotations(annotations=dummy)
    return raw
|
|
|
|
|
|
def _get_data(x, ch_idx):
    """Get the (n_ch, n_times) data array.

    Parameters
    ----------
    x : BaseRaw | Evoked
        Instance to read data from.
    ch_idx : array-like of int
        Indices of the channels to return.

    Returns
    -------
    data : ndarray
        Data of the selected channels.

    Raises
    ------
    TypeError
        If ``x`` is neither a Raw nor an Evoked instance. Returning ``None``
        for such input would only surface later as an unrelated arithmetic
        or comparison error in callers such as ``_check_snr``.
    """
    from ..evoked import Evoked
    from ..io import BaseRaw

    if isinstance(x, BaseRaw):
        # Element 0 of ``raw[picks]`` is the data array.
        return x[ch_idx][0]
    if isinstance(x, Evoked):
        return x.data[ch_idx]
    raise TypeError(
        f"Expected a Raw or Evoked instance, got {type(x).__name__}"
    )
|
|
|
|
|
|
def _check_snr(actual, desired, picks, min_tol, med_tol, msg, kind="MEG"):
    """Check the SNR of a set of channels.

    For each picked channel the SNR is the RMS of ``desired`` divided by the
    RMS of ``actual - desired``. The error RMS is clipped at 1e-60 to avoid
    division by zero. Every channel must reach ``min_tol``, and the median
    across channels must reach ``med_tol``.

    Parameters
    ----------
    actual, desired : BaseRaw | Evoked
        Instances whose data are compared (read via ``_get_data``).
    picks : array-like of int
        Channel indices to check.
    min_tol : float
        Minimum SNR required of every channel.
    med_tol : float
        Minimum median SNR across channels.
    msg : str | None
        Extra context appended in parentheses to failure messages; ``None``
        or ``""`` appends nothing.
    kind : str
        Channel-type label used in the median-SNR failure message.
    """
    actual_data = _get_data(actual, picks)
    desired_data = _get_data(desired, picks)
    # RMS over time (axis 1) for each channel
    bench_rms = np.sqrt(np.mean(desired_data * desired_data, axis=1))
    error = actual_data - desired_data
    error_rms = np.sqrt(np.mean(error * error, axis=1))
    np.clip(error_rms, 1e-60, np.inf, out=error_rms)  # avoid division by zero
    snrs = bench_rms / error_rms
    # ``msg=None`` (e.g. assert_meg_snr's default) must not render as " (None)"
    suffix = "" if msg is None or msg == "" else f" ({msg})"
    # min tol
    snr = snrs.min()
    bad_count = (snrs < min_tol).sum()
    assert bad_count == 0, (
        f"SNR (worst {snr:0.2f}) < {min_tol:0.2f} "
        f"for {bad_count}/{len(picks)} channels{suffix}"
    )
    # median tol
    snr = np.median(snrs)
    assert snr >= med_tol, f"{kind} SNR median {snr:0.2f} < {med_tol:0.2f}{suffix}"
|
|
|
|
|
|
def assert_meg_snr(
    actual, desired, min_tol, med_tol=500.0, chpi_med_tol=500.0, msg=None
):
    """Assert channel SNR of a certain level.

    Mostly useful for operations like Maxwell filtering that modify
    MEG channels while leaving EEG and others intact.

    Parameters
    ----------
    actual, desired : BaseRaw | Evoked
        Instances to compare.
    min_tol : float
        Minimum SNR that every MEG channel must reach.
    med_tol : float
        Minimum median SNR across MEG channels.
    chpi_med_tol : float | None
        Minimum median SNR across cHPI channels. If None, the cHPI pick
        comparison and the cHPI SNR check are skipped.
    msg : str | None
        Extra context appended to failure messages.

    Notes
    -----
    Channels that are neither MEG nor cHPI must match closely
    (``atol=1e-11, rtol=1e-5``).
    """
    from .._fiff.pick import pick_types

    if msg is None:
        msg = ""
    # MEG picks must come from *each* instance; picking both from ``desired``
    # would make the mismatch check below compare a set with itself.
    picks = pick_types(actual.info, meg=True, exclude=[])
    picks_desired = pick_types(desired.info, meg=True, exclude=[])
    assert_array_equal(picks, picks_desired, err_msg="MEG pick mismatch")
    chpis = pick_types(actual.info, meg=False, chpi=True, exclude=[])
    chpis_desired = pick_types(desired.info, meg=False, chpi=True, exclude=[])
    if chpi_med_tol is not None:
        assert_array_equal(chpis, chpis_desired, err_msg="cHPI pick mismatch")
    # Everything that is neither MEG nor cHPI should be left untouched
    others = np.setdiff1d(
        np.arange(len(actual.ch_names)), np.concatenate([picks, chpis])
    )
    others_desired = np.setdiff1d(
        np.arange(len(desired.ch_names)), np.concatenate([picks_desired, chpis_desired])
    )
    assert_array_equal(others, others_desired, err_msg="Other pick mismatch")
    if len(others) > 0:  # if non-MEG channels present
        assert_allclose(
            _get_data(actual, others),
            _get_data(desired, others),
            atol=1e-11,
            rtol=1e-5,
            err_msg="non-MEG channel mismatch",
        )
    _check_snr(actual, desired, picks, min_tol, med_tol, msg, kind="MEG")
    if chpi_med_tol is not None and len(chpis) > 0:
        _check_snr(actual, desired, chpis, 0.0, chpi_med_tol, msg, kind="cHPI")
|
|
|
|
|
|
def assert_snr(actual, desired, tol):
    """Assert actual and desired arrays are within some SNR tolerance.

    The SNR is ``||desired||_F / ||desired - actual||_F`` (Frobenius norms via
    ``scipy.linalg.norm``). A zero error norm is treated as infinite SNR, so
    identical inputs always pass, including all-zero arrays, which would
    otherwise produce 0/0 = nan.

    Parameters
    ----------
    actual, desired : array-like
        Arrays accepted by ``scipy.linalg.norm(..., ord="fro")``.
    tol : float
        Minimum acceptable SNR.
    """
    err_norm = linalg.norm(desired - actual, ord="fro")
    if err_norm == 0:
        snr = np.inf
    else:
        snr = linalg.norm(desired, ord="fro") / err_norm
    assert snr >= tol, f"{snr} < {tol}"
|
|
|
|
|
|
def assert_stcs_equal(stc1, stc2):
    """Assert that two source estimates match.

    ``times`` and ``data`` are compared with ``assert_allclose``, the first two
    ``vertices`` entries must be exactly equal, and finally ``tmin`` and
    ``tstep`` are compared with ``assert_allclose``.
    """
    for attr in ("times", "data"):
        assert_allclose(getattr(stc1, attr), getattr(stc2, attr))
    for ii in (0, 1):
        assert_array_equal(stc1.vertices[ii], stc2.vertices[ii])
    for attr in ("tmin", "tstep"):
        assert_allclose(getattr(stc1, attr), getattr(stc2, attr))
|
|
|
|
|
|
def _dig_sort_key(dig):
|
|
"""Sort dig keys."""
|
|
return (dig["kind"], dig["ident"])
|
|
|
|
|
|
def assert_dig_allclose(info_py, info_bin, limit=None):
    """Assert that two sets of digitization points are close.

    Parameters
    ----------
    info_py, info_bin : Info | DigMontage
        Both must be Info instances or both DigMontage instances.
    limit : int | None
        If given, only the first ``limit`` points (after sorting both lists by
        ``(kind, ident)``) are compared point-by-point; the total number of
        points must match regardless.

    Notes
    -----
    For Info inputs whose dig contains ``FIFF.FIFFV_POINT_EXTRA`` points, the
    results of ``fit_sphere_to_headshape`` (radius and both origins) are also
    compared. This step is skipped for DigMontage inputs.
    """
    from .._fiff.constants import FIFF
    from .._fiff.meas_info import Info
    from ..bem import fit_sphere_to_headshape
    from ..channels.montage import DigMontage

    # test dig positions
    if isinstance(info_py, Info):
        assert isinstance(info_bin, Info)
        dig_py, dig_bin = info_py["dig"], info_bin["dig"]
    else:
        assert isinstance(info_bin, DigMontage)
        assert isinstance(info_py, DigMontage)
        dig_py, dig_bin = info_py.dig, info_bin.dig
        # No Info is available, so the sphere-fit comparison below is skipped.
        info_py = info_bin = None
    assert isinstance(dig_py, list)
    assert isinstance(dig_bin, list)
    dig_py = sorted(dig_py, key=_dig_sort_key)
    dig_bin = sorted(dig_bin, key=_dig_sort_key)
    assert len(dig_py) == len(dig_bin)

    point_pairs = zip(dig_py[:limit], dig_bin[:limit])
    for ii, (pt_py, pt_bin) in enumerate(point_pairs):
        for key in ("ident", "kind", "coord_frame"):
            assert pt_py[key] == pt_bin[key], key
        assert_allclose(
            pt_py["r"],
            pt_bin["r"],
            rtol=1e-5,
            atol=1e-5,
            err_msg=f"Failure on {ii}:\n{pt_py['r']}\n{pt_bin['r']}",
        )

    has_extra = any(d["kind"] == FIFF.FIFFV_POINT_EXTRA for d in dig_py)
    if not has_extra or info_py is None:
        return
    r_bin, o_head_bin, o_dev_bin = fit_sphere_to_headshape(
        info_bin, units="m", verbose="error"
    )
    r_py, o_head_py, o_dev_py = fit_sphere_to_headshape(
        info_py, units="m", verbose="error"
    )
    assert_allclose(r_py, r_bin, atol=1e-6)
    assert_allclose(o_dev_py, o_dev_bin, rtol=1e-5, atol=1e-6)
    assert_allclose(o_head_py, o_head_bin, rtol=1e-5, atol=1e-6)
|
|
|
|
|
|
def _click_ch_name(fig, ch_index=0, button=1):
    """Simulate a click on a channel-name tick label of a browse-style figure.

    Clicks at the pixel center of y-tick label ``ch_index`` of
    ``fig.mne.ax_main`` using mouse ``button``.
    """
    from ..viz.utils import _fake_click

    # Draw first, presumably so tick-label window extents are up to date.
    fig.canvas.draw()
    ax = fig.mne.ax_main
    label = ax.get_yticklabels()[ch_index]
    extent = label.get_window_extent()
    center = (extent.intervalx.mean(), extent.intervaly.mean())
    _fake_click(fig, ax, center, xform="pix", button=button)
|
|
|
|
|
|
def _get_suptitle(fig):
    """Get fig suptitle (shim for matplotlib < 3.8.0).

    Returns the suptitle text, or ``""`` when the figure has no suptitle
    (matching ``Figure.get_suptitle`` in matplotlib >= 3.8).
    """
    # TODO: obsolete when minimum MPL version is 3.8
    if check_version("matplotlib", "3.8"):
        return fig.get_suptitle()
    # Older matplotlib keeps the suptitle Text (or None) on ``_suptitle``;
    # prefer it over guessing from ``fig.texts``, which can also hold
    # sup{x,y}labels or be empty.
    missing = object()
    suptitle = getattr(fig, "_suptitle", missing)
    if suptitle is missing:
        # unreliable hack; should work in most tests as we rarely use `sup_{x,y}label`
        return fig.texts[0].get_text()
    return "" if suptitle is None else suptitle.get_text()
|