"""The check functions."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numbers
import operator
import os
import re
from builtins import input  # noqa: UP029
from difflib import get_close_matches
from importlib import import_module
from inspect import signature
from pathlib import Path

import numpy as np

from ..defaults import HEAD_SIZE_DEFAULT, _handle_default
from ..fixes import _compare_version, _median_complex
from ._logging import _record_warnings, _verbose_safe_false, logger, verbose, warn


def _ensure_int(x, name="unknown", must_be="an int", *, extra=""):
    """Ensure a variable is an integer.

    Parameters
    ----------
    x : object
        The value to check. Anything supporting ``operator.index`` except
        ``bool`` is accepted.
    name : str
        Name of the value, used in the error message.
    must_be : str
        Description of what the value must be, used in the error message.
    extra : str
        Extra text appended to the error message.

    Returns
    -------
    x : int
        The value converted to a built-in ``int``.

    Raises
    ------
    TypeError
        If ``x`` is a bool or does not support ``operator.index``.
    """
    # This is preferred over numbers.Integral, see:
    # https://github.com/scipy/scipy/pull/7351#issuecomment-299713159
    extra = f" {extra}" if extra else extra
    try:
        # someone passing True/False is much more likely to be an error than
        # intentional usage
        if isinstance(x, bool):
            raise TypeError()
        x = int(operator.index(x))
    except TypeError:
        raise TypeError(f"{name} must be {must_be}{extra}, got {type(x)}")
    return x


def _check_integer_or_list(arg, name):
    """Validate arguments that should be an integer or a list.

    Always returns a list.
    """
    if not isinstance(arg, list):
        arg = [_ensure_int(arg, name=name, must_be="an integer or a list")]
    return arg


def _format_endings(endings):
    """Format filename endings as ``"a, b or c"`` for user-facing messages."""
    # A bare string is a single ending (str.endswith treats it that way), so
    # it must not be iterated character by character below.
    if isinstance(endings, str):
        endings = (endings,)
    endings = tuple(endings)
    if len(endings) <= 1:
        # Avoid the " or x" artifact a single ending would otherwise produce.
        return "".join(endings)
    return " or ".join([", ".join(endings[:-1]), endings[-1]])


def check_fname(fname, filetype, endings, endings_err=()):
    """Enforce MNE filename conventions.

    Parameters
    ----------
    fname : str
        Name of the file.
    filetype : str
        Type of file. e.g., ICA, Epochs etc.
    endings : tuple
        Acceptable endings for the filename. A non-matching name only warns.
    endings_err : tuple
        Obligatory possible endings for the filename. A non-matching name
        raises an error.

    Raises
    ------
    OSError
        If ``endings_err`` is non-empty and ``fname`` ends with none of them.
    """
    _validate_type(fname, "path-like", "fname")
    fname = str(fname)
    if len(endings_err) > 0 and not fname.endswith(endings_err):
        print_endings = _format_endings(endings_err)
        raise OSError(
            f"The filename ({fname}) for file type {filetype} must end "
            f"with {print_endings}"
        )
    if not fname.endswith(endings):
        print_endings = _format_endings(endings)
        warn(
            f"This filename ({fname}) does not conform to MNE naming conventions. "
            f"All {filetype} files should end with {print_endings}"
        )


def check_version(library, min_version="0.0", *, strip=True, return_version=False):
    r"""Check minimum library version required.

    Parameters
    ----------
    library : str
        The library name to import. Must have a ``__version__`` property.
    min_version : str
        The minimum version string. Anything that matches
        ``'(\d+ | [a-z]+ | \.)'``. Can also be empty to skip version
        check (just check for library presence).
    strip : bool
        If True (default), then PEP440 development markers like ``.devN``
        will be stripped from the version. This makes it so that
        ``check_version('mne', '1.1')`` will be ``True`` even when on version
        ``'1.1.dev0'`` (prerelease/dev version). This option is provided for
        backward compatibility with the behavior of ``LooseVersion``, and
        diverges from how modern parsing in
        ``packaging.version.parse`` works.

        .. versionadded:: 1.0
    return_version : bool
        If True (default False), also return the version (can be None if the
        library is missing).

        .. versionadded:: 1.0

    Returns
    -------
    ok : bool
        True if the library exists with at least the specified version.
    version : str | None
        The version. Only returned when ``return_version=True``.
    """
    ok = True
    version = None
    try:
        library = import_module(library)
    except ImportError:
        ok = False
    else:
        # "" and "0.0" both mean "presence check only".
        need_compare = min_version and min_version != "0.0"
        if need_compare or return_version:
            version = library.__version__
            if strip:
                version = _strip_dev(version)
        if need_compare:
            if _compare_version(version, "<", min_version):
                ok = False
    out = (ok, version) if return_version else ok
    return out


def _strip_dev(version):
    """Strip dev/rc/local markers from a version string (else return it as-is)."""
    # First capturing group () is what we want to keep, at the beginning:
    #
    # - at least one numeral, then
    # - repeats of {dot, at least one numeral}
    #
    # The rest (consume to the end of the string) is the stuff we want to cut
    # off:
    #
    # - A period (maybe), then
    # - "dev", "rc", or "+", then
    # - numerals, periods, dashes, and "a" through "g" (hex chars)
    #
    # Thanks https://www.regextester.com !
    exp = r"^([0-9]+(?:\.[0-9]+)*)\.?(?:dev|rc|\+)[0-9+a-g\.\-]+$"
    match = re.match(exp, version)
    return match.groups()[0] if match is not None else version


def _require_version(lib, what, version="0.0"):
    """Require library for a purpose.

    Raises
    ------
    ImportError
        If ``lib`` is missing or older than ``version``.
    """
    ok, got = check_version(lib, version, return_version=True)
    if not ok:
        extra = f" (version >= {version})" if version != "0.0" else ""
        why = "package was not found" if got is None else f"got {repr(got)}"
        raise ImportError(f"The {lib} package{extra} is required to {what}, {why}")


def _import_h5py():
    """Import and return ``h5py``, raising an informative ImportError if absent."""
    _require_version("h5py", "read MATLAB files >= v7.3")
    import h5py

    return h5py


def _import_h5io_funcs():
    """Return ``(h5io.read_hdf5, h5io.write_hdf5)`` via ``_soft_import``."""
    h5io = _soft_import("h5io", "HDF5-based I/O")
    return h5io.read_hdf5, h5io.write_hdf5


def _import_pymatreader_funcs(purpose):
    """Return ``pymatreader.read_mat`` via ``_soft_import`` for ``purpose``."""
    pymatreader = _soft_import("pymatreader", purpose)
    return pymatreader.read_mat


# adapted from scikit-learn utils/validation.py
def check_random_state(seed):
    """Turn seed into a numpy.random.mtrand.RandomState instance.

    If seed is None, return the RandomState singleton used by np.random.mtrand.
    If seed is an int, return a new RandomState instance seeded with seed.
    If seed is already a RandomState instance, return it.
    If seed is a np.random.Generator instance, it is also returned unchanged.
    Otherwise raise ValueError.
    """
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    if isinstance(seed, (int, np.integer)):
        return np.random.mtrand.RandomState(seed)
    if isinstance(seed, np.random.mtrand.RandomState):
        return seed
    if isinstance(seed, np.random.Generator):
        return seed
    raise ValueError(
        f"{seed!r} cannot be used to seed a numpy.random.mtrand.RandomState instance"
    )


def _check_event_id(event_id, events):
    """Check event_id and convert to default format.

    Returns a dict mapping event names (str) to event codes (int). When
    ``event_id`` is None, the unique values of ``events[:, 2]`` are used.
    """
    # check out event_id dict
    if event_id is None:  # convert to int to make typing-checks happy
        event_id = list(np.unique(events[:, 2]))
    if isinstance(event_id, dict):
        for key in event_id.keys():
            _validate_type(key, str, "Event names")
        event_id = {
            key: _ensure_int(val, f"event_id[{key}]") for key, val in event_id.items()
        }
    elif isinstance(event_id, list):
        event_id = [_ensure_int(v, f"event_id[{vi}]") for vi, v in enumerate(event_id)]
        event_id = dict(zip((str(i) for i in event_id), event_id))
    else:
        event_id = _ensure_int(event_id, "event_id")
        event_id = {str(event_id): event_id}
    return event_id


@verbose
def _check_fname(
    fname,
    overwrite=False,
    must_exist=False,
    name="File",
    need_dir=False,
    *,
    check_bids_split=False,
    verbose=None,
):
    """Check for file existence, and return its absolute path.

    ``overwrite="read"`` allows an existing file without logging that it
    will be overwritten.
    """
    _validate_type(fname, "path-like", name)
    # special case for MNE-BIDS, check split
    fname_path = Path(fname)
    if check_bids_split:
        try:
            from mne_bids import BIDSPath
        except Exception:
            pass
        else:
            if isinstance(fname, BIDSPath) and fname.split is not None:
                raise ValueError(
                    f"Passing a BIDSPath {name} with `{fname.split=}` is unsafe as it "
                    "can unexpectedly lead to invalid BIDS split naming. Explicitly "
                    f"set `{name}.split = None` to avoid ambiguity. If you want the "
                    f"old misleading split naming, you can pass `str({name})`."
                )

    fname = fname_path.expanduser().absolute()
    del fname_path

    if fname.exists():
        if not overwrite:
            raise FileExistsError(
                "Destination file exists. Please use option "
                '"overwrite=True" to force overwriting.'
            )
        elif overwrite != "read":
            logger.info("Overwriting existing file.")
        if must_exist:
            if need_dir:
                if not fname.is_dir():
                    raise OSError(
                        f"Need a directory for {name} but found a file at {fname}"
                    )
            else:
                if not fname.is_file():
                    raise OSError(
                        f"Need a file for {name} but found a directory at {fname}"
                    )
            if not os.access(fname, os.R_OK):
                raise PermissionError(f"{name} does not have read permissions: {fname}")
    elif must_exist:
        raise FileNotFoundError(f'{name} does not exist: "{fname}"')

    return fname


def _check_subject(
    first,
    second,
    *,
    raise_error=True,
    first_kind="class subject attribute",
    second_kind="input subject",
):
    """Get subject name from class.

    ``second`` (the explicit input) wins if given, and must agree with
    ``first`` when both are set. Returns None only if both are None and
    ``raise_error`` is not True.
    """
    if second is not None:
        _validate_type(second, "str", "subject input")
        if first is not None and first != second:
            raise ValueError(
                f"{first_kind} ({repr(first)}) did not match "
                f"{second_kind} ({second})"
            )
        return second
    elif first is not None:
        _validate_type(first, "str", f"Either {second_kind} subject or {first_kind}")
        return first
    elif raise_error is True:
        raise ValueError(f"Neither {second_kind} subject nor {first_kind} was a string")
    return None


def _check_preload(inst, msg):
    """Ensure data are preloaded.

    TFR, Evoked, Spectrum and source-estimate objects are always in memory
    and pass; other instances must have ``inst.preload`` set.
    """
    from ..epochs import BaseEpochs
    from ..evoked import Evoked
    from ..source_estimate import _BaseSourceEstimate
    from ..time_frequency import BaseTFR
    from ..time_frequency.spectrum import BaseSpectrum

    if isinstance(inst, (BaseTFR, Evoked, BaseSpectrum, _BaseSourceEstimate)):
        pass
    else:
        name = "epochs" if isinstance(inst, BaseEpochs) else "raw"
        if not inst.preload:
            raise RuntimeError(
                "By default, MNE does not load data into main memory to "
                "conserve resources. " + msg + f" requires {name} data to be "
                "loaded. Use preload=True (or string) in the constructor or "
                f"{name}.load_data()."
            )
        if name == "epochs":
            inst._handle_empty("raise", msg)


def _check_compensation_grade(info1, info2, name1, name2="data", ch_names=None):
    """Ensure that objects have same compensation_grade.

    Does nothing if either info is None. If ``ch_names`` is given, the grades
    are compared on copies restricted to those channels.
    """
    from .._fiff.compensator import get_current_comp
    from .._fiff.meas_info import Info
    from .._fiff.pick import pick_channels, pick_info

    for t_info in (info1, info2):
        if t_info is None:
            return
        assert isinstance(t_info, Info), t_info  # or internal code is wrong

    if ch_names is not None:
        info1 = info1.copy()
        info2 = info2.copy()
        # pick channels
        for t_info in [info1, info2]:
            if t_info["comps"]:
                with t_info._unlock():
                    t_info["comps"] = []
            picks = pick_channels(t_info["ch_names"], ch_names, ordered=False)
            pick_info(t_info, picks, copy=False)

    # "or 0" here aliases None -> 0, as they are equivalent
    grade1 = get_current_comp(info1) or 0
    grade2 = get_current_comp(info2) or 0

    # perform check
    if grade1 != grade2:
        raise RuntimeError(
            f"Compensation grade of {name1} ({grade1}) and {name2} ({grade2}) "
            "do not match"
        )


def _soft_import(name, purpose, strict=True):
    """Import soft dependencies, providing informative errors on failure.

    Parameters
    ----------
    name : str
        Name of the module to be imported. For example, 'pandas'.
    purpose : str
        A very brief statement (formulated as a noun phrase) explaining what
        functionality the package provides to MNE-Python.
    strict : bool
        Whether to raise an error if module import fails.

    Returns
    -------
    mod : module | False
        The imported module, or False if the import failed and ``strict`` is
        False.

    Raises
    ------
    RuntimeError
        If the import fails and ``strict`` is True.
    """

    # so that error msg lines are aligned
    def indent(x):
        return x.rjust(len(x) + 14)

    # Mapping import namespaces to their pypi package name
    pip_name = dict(
        sklearn="scikit-learn",
        mne_bids="mne-bids",
        mne_nirs="mne-nirs",
        mne_features="mne-features",
        mne_qt_browser="mne-qt-browser",
        mne_connectivity="mne-connectivity",
        mne_gui_addons="mne-gui-addons",
        pyvista="pyvistaqt",
    ).get(name, name)

    try:
        mod = import_module(name)
        return mod
    except ImportError:  # also covers ModuleNotFoundError (a subclass)
        if strict:
            raise RuntimeError(
                f"For {purpose} to work, the {name} module is needed, "
                + "but it could not be imported.\n"
                + "\n".join(
                    (
                        indent(
                            "use the following installation method "
                            "appropriate for your environment:"
                        ),
                        indent(f"'pip install {pip_name}'"),
                        indent(f"'conda install -c conda-forge {pip_name}'"),
                    )
                )
            )
        else:
            return False


def _check_pandas_installed(strict=True):
    """Aux function."""
    return _soft_import("pandas", "dataframe integration", strict=strict)


def _check_eeglabio_installed(strict=True):
    """Aux function."""
    return _soft_import("eeglabio", "exporting to EEGLab", strict=strict)


def _check_edfio_installed(strict=True):
    """Aux function."""
    return _soft_import("edfio", "exporting to EDF", strict=strict)


def _check_pybv_installed(strict=True):
    """Aux function."""
    return _soft_import("pybv", "exporting to BrainVision", strict=strict)


def _check_pymatreader_installed(strict=True):
    """Aux function."""
    return _soft_import("pymatreader", "loading v7.3 (HDF5) .MAT files", strict=strict)


def _check_pandas_index_arguments(index, valid):
    """Check pandas index arguments.

    Returns None if ``index`` is None, otherwise a list of index names.
    """
    if index is None:
        return
    if isinstance(index, str):
        index = [index]
    if not isinstance(index, list):
        raise TypeError(
            "index must be `None` or a string or list of strings, got type "
            f"{type(index)}."
        )
    invalid = set(index) - set(valid)
    if invalid:
        plural = ("is not a valid option", "are not valid options")[
            int(len(invalid) > 1)
        ]
        raise ValueError(
            '"{}" {}. Valid index options are `None`, "{}".'.format(
                '", "'.join(invalid), plural, '", "'.join(valid)
            )
        )
    return index


def _check_time_format(time_format, valid, meas_date=None):
    """Check time_format argument.

    Falls back to "timedelta" (with a warning) if "datetime" is requested
    but ``meas_date`` is None.
    """
    if time_format not in valid and time_format is not None:
        valid_str = '", "'.join(valid)
        raise ValueError(
            f'"{time_format}" is not a valid time format. Valid options are '
            f'"{valid_str}" and None.'
        )
    # allow datetime only if meas_date available
    if time_format == "datetime" and meas_date is None:
        warn(
            "Cannot convert to Datetime when raw.info['meas_date'] is "
            "None. Falling back to Timedelta."
        )
        time_format = "timedelta"
    return time_format


def _check_ch_locs(info, picks=None, ch_type=None):
    """Check if channel locations exist.

    Parameters
    ----------
    info : Info | None
        `~mne.Info` instance.
    picks : list of int
        Channel indices to consider. If provided, ``ch_type`` must be ``None``.
    ch_type : str | None
        The channel type to restrict the check to. If ``None``, check all
        channel types. If provided, ``picks`` must be ``None``.

    Returns
    -------
    has_locs : bool
        False if all 3D positions are zero, all are non-finite, or all are
        close to zero; True otherwise.
    """
    from .._fiff.pick import _picks_to_idx, pick_info

    if picks is not None and ch_type is not None:
        raise ValueError("Either picks or ch_type may be provided, not both")

    if picks is not None:
        info = pick_info(info=info, sel=picks)
    elif ch_type is not None:
        picks = _picks_to_idx(info=info, picks=ch_type, none=ch_type)
        info = pick_info(info=info, sel=picks)

    chs = info["chs"]
    locs3d = np.array([ch["loc"][:3] for ch in chs])
    return not (
        (locs3d == 0).all() or (~np.isfinite(locs3d)).all() or np.allclose(locs3d, 0.0)
    )


def _is_numeric(n):
    """Return True if ``n`` is a ``numbers.Number``."""
    return isinstance(n, numbers.Number)


class _IntLike:
    """``isinstance`` target accepting anything ``_ensure_int`` accepts."""

    @classmethod
    def __instancecheck__(cls, other):
        try:
            _ensure_int(other)
        except TypeError:
            return False
        else:
            return True


int_like = _IntLike()
path_like = (str, Path, os.PathLike)


class _Callable:
    """``isinstance`` target accepting any callable."""

    @classmethod
    def __instancecheck__(cls, other):
        return callable(other)


class _Sparse:
    """``isinstance`` target accepting scipy sparse matrices/arrays."""

    @classmethod
    def __instancecheck__(cls, other):
        from scipy import sparse

        return sparse.issparse(other)


# String aliases accepted by _validate_type, mapped to isinstance targets.
_multi = {
    "str": (str,),
    "numeric": (np.floating, float, int_like),
    "path-like": path_like,
    "int-like": (int_like,),
    "callable": (_Callable(),),
    "array-like": (list, tuple, set, np.ndarray),
    "sparse": (_Sparse(),),
}


def _validate_type(item, types=None, item_name=None, type_name=None, *, extra=""):
    """Validate that `item` is an instance of `types`.

    Parameters
    ----------
    item : object
        The thing to be checked.
    types : type | str | tuple of types | tuple of str
         The types to be checked against.
         If str, must be one of {'int', 'int-like', 'str', 'numeric', 'info',
         'path-like', 'callable', 'array-like', 'sparse'}.
         If a tuple of str is passed, use 'int-like' and not 'int' for
         integers.
    item_name : str | None
        Name of the item to show inside the error message.
    type_name : str | None
        Possible types to show inside the error message that the checked item
        can be.
    extra : str
        Extra text to append to the warning.

    Raises
    ------
    TypeError
        If ``item`` is not an instance of any of ``types``.
    """
    _item_name = "Item" if item_name is None else item_name
    if types == "int":
        # Use the same default name as the generic branch below (previously
        # None was passed through, yielding "None must be an int").
        _ensure_int(item, name=_item_name, extra=extra)
        return  # terminate prematurely
    elif types == "info":
        from .._fiff.meas_info import Info as types

    if not isinstance(types, (list, tuple)):
        types = [types]

    check_types = sum(
        (
            (type(None),)
            if type_ is None
            else (type_,)
            if not isinstance(type_, str)
            else _multi[type_]
            for type_ in types
        ),
        (),
    )
    extra = f" {extra}" if extra else extra
    if not isinstance(item, check_types):
        if type_name is None:
            type_name = [
                "None"
                if cls_ is None
                else cls_.__name__
                if not isinstance(cls_, str)
                else cls_
                for cls_ in types
            ]
            if len(type_name) == 1:
                type_name = type_name[0]
            elif len(type_name) == 2:
                type_name = " or ".join(type_name)
            else:
                type_name[-1] = "or " + type_name[-1]
                type_name = ", ".join(type_name)
        raise TypeError(
            f"{_item_name} must be an instance of {type_name}{extra}, "
            f"got {type(item)} instead."
        )


def _check_range(val, min_val, max_val, name, min_inclusive=True, max_inclusive=True):
    """Check that item is within range.

    Parameters
    ----------
    val : int | float
        The value to be checked.
    min_val : int | float
        The minimum value allowed.
    max_val : int | float
        The maximum value allowed.
    name : str
        The name of the value.
    min_inclusive : bool
        Whether ``val`` is allowed to be ``min_val``.
    max_inclusive : bool
        Whether ``val`` is allowed to be ``max_val``.

    Raises
    ------
    ValueError
        If ``val`` is outside the allowed range.
    """
    below_min = val < min_val if min_inclusive else val <= min_val
    above_max = val > max_val if max_inclusive else val >= max_val

    if below_min or above_max:
        error_str = f"The value of {name} must be between {min_val} "
        if min_inclusive:
            error_str += "inclusive "
        error_str += f"and {max_val}"
        if max_inclusive:
            # leading space: previously rendered as e.g. "1inclusive "
            error_str += " inclusive"
        error_str += f", got {val}"
        raise ValueError(error_str)


def _path_like(item):
    """Validate that `item` is `path-like`.

    Parameters
    ----------
    item : object
        The thing to be checked.

    Returns
    -------
    bool
        ``True`` if `item` is a `path-like` object; ``False`` otherwise.
    """
    try:
        _validate_type(item, types="path-like")
        return True
    except TypeError:
        return False


def _check_if_nan(data, msg=" to be plotted"):
    """Raise if any of the values are NaN.

    Non-finite values (e.g. infinities) also trigger the error, since the
    check uses ``np.isfinite``.
    """
    if not np.isfinite(data).all():
        # The default ``msg`` carries its own leading space; strip it so the
        # message does not contain a doubled space.
        raise ValueError(f"Some of the values {msg.strip()} are NaN.")


@verbose
def _check_info_inv(info, forward, data_cov=None, noise_cov=None, verbose=None):
    """Return good channels common to forward model and covariance matrices.

    Returns
    -------
    picks : list of int
        Indices into ``info["ch_names"]`` of the retained channels.
    """
    from .._fiff.pick import pick_types

    # get a list of all channel names:
    fwd_ch_names = forward["info"]["ch_names"]

    # handle channels from forward model and info:
    ch_names = _compare_ch_names(info["ch_names"], fwd_ch_names, info["bads"])

    # make sure that no reference channels are left:
    ref_chs = pick_types(info, meg=False, ref_meg=True)
    ref_chs = [info["ch_names"][ch] for ch in ref_chs]
    ch_names = [ch for ch in ch_names if ch not in ref_chs]

    # inform about excluding channels:
    if (
        data_cov is not None
        and set(info["bads"]) != set(data_cov["bads"])
        and (len(set(ch_names).intersection(data_cov["bads"])) > 0)
    ):
        logger.info(
            'info["bads"] and data_cov["bads"] do not match, '
            "excluding bad channels from both."
        )
    if (
        noise_cov is not None
        and set(info["bads"]) != set(noise_cov["bads"])
        and (len(set(ch_names).intersection(noise_cov["bads"])) > 0)
    ):
        logger.info(
            'info["bads"] and noise_cov["bads"] do not match, '
            "excluding bad channels from both."
        )

    # handle channels from data cov if data cov is not None
    # Note: data cov is supposed to be None in tf_lcmv
    if data_cov is not None:
        ch_names = _compare_ch_names(ch_names, data_cov.ch_names, data_cov["bads"])

    # handle channels from noise cov if noise cov available:
    if noise_cov is not None:
        ch_names = _compare_ch_names(ch_names, noise_cov.ch_names, noise_cov["bads"])

    # inform about excluding any channels apart from bads and reference
    # (list + list builds a new list, so info["bads"] is not mutated below)
    all_bads = info["bads"] + ref_chs
    if data_cov is not None:
        all_bads += data_cov["bads"]
    if noise_cov is not None:
        all_bads += noise_cov["bads"]
    dropped_nonbads = set(info["ch_names"]) - set(ch_names) - set(all_bads)
    if dropped_nonbads:
        logger.info(
            f"Excluding {len(dropped_nonbads)} channel(s) missing from the "
            "provided forward operator and/or covariance matrices"
        )

    picks = [info["ch_names"].index(k) for k in ch_names if k in info["ch_names"]]

    return picks


def _compare_ch_names(names1, names2, bads):
    """Return channel names of common and good channels.

    The order of ``names1`` is preserved.
    """
    # sets make the membership tests O(1) without changing the result
    names2 = set(names2)
    bads = set(bads)
    ch_names = [ch for ch in names1 if ch not in bads and ch in names2]
    return ch_names


def _check_channels_spatial_filter(ch_names, filters):
    """Return data channel indices to be used with spatial filter.

    Unlike ``pick_channels``, this respects the order of ch_names.

    Raises
    ------
    ValueError
        If the filter uses a channel that is absent from ``ch_names``.
    """
    data_names = set(ch_names)
    filter_names = set(filters["ch_names"])
    # first check for channel discrepancies between filter and data:
    for ch_name in filters["ch_names"]:
        if ch_name not in data_names:
            raise ValueError(
                f"The spatial filter was computed with channel {ch_name} "
                "which is not present in the data. You should "
                "compute a new spatial filter restricted to the "
                "good data channels."
            )
    # then compare list of channels and get selection based on data:
    sel = [ii for ii, ch_name in enumerate(ch_names) if ch_name in filter_names]
    return sel


def _check_rank(rank):
    """Check rank parameter (None, dict, "full" or "info")."""
    _validate_type(rank, (None, dict, str), "rank")
    if isinstance(rank, str):
        if rank not in ["full", "info"]:
            raise ValueError(f'rank, if str, must be "full" or "info", got {rank}')
    return rank


def _check_one_ch_type(method, info, forward, data_cov=None, noise_cov=None):
    """Check number of sensor types and presence of noise covariance matrix.

    Returns
    -------
    noise_cov : Covariance | CrossSpectralDensity
        A copy of the given noise covariance, or an ad-hoc one if None.
    picks : list of int
        The channel picks.
    allow_mismatch : bool
        True only when an ad-hoc noise covariance was created.
    """
    from .._fiff.pick import _contains_ch_type, pick_info
    from ..cov import Covariance, make_ad_hoc_cov
    from ..time_frequency.csd import CrossSpectralDensity

    if isinstance(data_cov, CrossSpectralDensity):
        _validate_type(noise_cov, [None, CrossSpectralDensity], "noise_cov")
        # FIXME: every CSD channel is used and picks are not reconciled with
        # info/forward here.
        picks = list(range(len(data_cov.ch_names)))
        info_pick = info
    else:
        _validate_type(noise_cov, [None, Covariance], "noise_cov")
        picks = _check_info_inv(
            info,
            forward,
            data_cov=data_cov,
            noise_cov=noise_cov,
            verbose=_verbose_safe_false(),
        )
        info_pick = pick_info(info, picks)
    ch_types = [_contains_ch_type(info_pick, tt) for tt in ("mag", "grad", "eeg")]
    if sum(ch_types) > 1:
        if noise_cov is None:
            raise ValueError(
                "Source reconstruction with several sensor types"
                " requires a noise covariance matrix to be "
                "able to apply whitening."
            )
    if noise_cov is None:
        noise_cov = make_ad_hoc_cov(info_pick, std=1.0)
        allow_mismatch = True
    else:
        noise_cov = noise_cov.copy()
        if isinstance(noise_cov, Covariance) and "estimator" in noise_cov:
            del noise_cov["estimator"]
        allow_mismatch = False
    _validate_type(noise_cov, (Covariance, CrossSpectralDensity), "noise_cov")
    return noise_cov, picks, allow_mismatch


def _check_depth(depth, kind="depth_mne"):
    """Check depth options.

    A non-dict ``depth`` is treated as the ``exp`` value (float or None) and
    merged with the defaults for ``kind``.
    """
    if not isinstance(depth, dict):
        depth = dict(exp=None if depth is None else float(depth))
    return _handle_default(kind, depth)


def _check_dict_keys(mapping, valid_keys, key_description, valid_key_source):
    """Check that the keys in dictionary are valid against a set list.

    Return the input dictionary if it is valid,
    otherwise raise a ValueError with a readable error message.

    Parameters
    ----------
    mapping : dict
        The user-provided dict whose keys we want to check.
    valid_keys : iterable
        The valid keys.
    key_description : str
        Description of the keys in ``mapping``, e.g., "channel name(s)" or
        "annotation(s)".
    valid_key_source : str
        Description of the ``valid_keys`` source, e.g., "info dict" or
        "annotations in the data".

    Returns
    -------
    mapping
        If all keys are valid the input dict is returned unmodified.
    """
    missing = set(mapping) - set(valid_keys)
    if len(missing):
        _is = "are" if len(missing) > 1 else "is"
        msg = (
            f"Invalid {key_description} {missing} {_is} not present in "
            f"{valid_key_source}"
        )
        raise ValueError(msg)

    return mapping


def _check_option(parameter, value, allowed_values, extra=""):
    """Check the value of a parameter against a list of valid options.

    Return the value if it is valid, otherwise raise a ValueError with a
    readable error message.

    Parameters
    ----------
    parameter : str
        The name of the parameter to check. This is used in the error message.
    value : any type
        The value of the parameter to check.
    allowed_values : list
        The list of allowed values for the parameter.
    extra : str
        Extra string to append to the invalid value sentence, e.g.
        "when using ico mode".

    Raises
    ------
    ValueError
        When the value of the parameter is not one of the valid options.

    Returns
    -------
    value : any type
        The value if it is valid.
    """
    if value in allowed_values:
        return value

    # Prepare a nice error message for the user
    extra = f" {extra}" if extra else extra
    msg = (
        "Invalid value for the '{parameter}' parameter{extra}. "
        "{options}, but got {value!r} instead."
    )
    allowed_values = list(allowed_values)  # e.g., if a dict was given
    if len(allowed_values) == 1:
        options = f"The only allowed value is {repr(allowed_values[0])}"
    else:
        options = "Allowed values are "
        if len(allowed_values) == 2:
            options += " and ".join(repr(v) for v in allowed_values)
        else:
            options += ", ".join(repr(v) for v in allowed_values[:-1])
            options += f", and {repr(allowed_values[-1])}"
    raise ValueError(
        msg.format(parameter=parameter, options=options, value=value, extra=extra)
    )


def _check_all_same_channel_names(instances):
    """Check if a collection of instances all have the same channels."""
    ch_names = instances[0].info["ch_names"]
    for inst in instances:
        if ch_names != inst.info["ch_names"]:
            return False
    return True


def _check_combine(mode, valid=("mean", "median", "std"), axis=0):
    """Return a callable that combines data along ``axis`` according to ``mode``.

    ``mode`` may be "mean", "std", "median" (or ``np.median``), or any
    callable, which is returned unchanged.
    """
    # XXX TODO Possibly de-duplicate with _make_combine_callable of mne/viz/utils.py
    if mode == "mean":

        def fun(data):
            return np.mean(data, axis=axis)

    elif mode == "std":

        def fun(data):
            return np.std(data, axis=axis)

    elif mode == "median" or mode == np.median:

        def fun(data):
            return _median_complex(data, axis=axis)

    elif callable(mode):
        fun = mode
    else:
        raise ValueError(
            "Combine option must be "
            + ", ".join(valid)
            + f" or callable, got {mode} (type {type(mode)})."
        )
    return fun


def _check_src_normal(pick_ori, src):
    """Raise if ``pick_ori="normal"`` is used with a non-surface/discrete src."""
    from ..source_space import SourceSpaces

    _validate_type(src, SourceSpaces, "src")
    if pick_ori == "normal" and src.kind not in ("surface", "discrete"):
        raise RuntimeError(
            "Normal source orientation is supported only for "
            "surface or discrete SourceSpaces, got type "
            f"{src.kind}"
        )


def _check_stc_units(stc, threshold=1e-7):  # 100 nAm threshold for warning
    """Warn if the largest absolute value in ``stc.data`` exceeds ``threshold``."""
    max_cur = np.max(np.abs(stc.data))
    if max_cur > threshold:
        warn(
            f"The maximum current magnitude is {1e9 * max_cur:.1f} nAm, which is very "
            "large. Are you trying to apply the forward model to noise-normalized "
            "(dSPM, sLORETA, or eLORETA) values? The result will only be "
            "correct if currents (in units of Am) are used."
        )


def _check_qt_version(*, return_api=False, check_usable_display=True):
    """Check if Qt is installed.

    Returns the Qt version string (or None if Qt is unavailable or unusable),
    plus the qtpy API name if ``return_api`` is True.
    """
    from ..viz.backends._utils import _init_mne_qtapp

    try:
        from qtpy import API_NAME as api
        from qtpy import QtCore
    except Exception:
        api = version = None
    else:
        try:  # pyside
            version = QtCore.__version__
        except AttributeError:
            version = QtCore.QT_VERSION_STR
        # Having Qt installed is not enough -- sometimes the app is unusable
        # for example because there is no usable display (e.g., on a server),
        # so we have to try instantiating one to actually know.
        if check_usable_display:
            try:
                _init_mne_qtapp()
            except Exception:
                api = version = None
    if return_api:
        return version, api
    else:
        return version


def _check_sphere(sphere, info=None, sphere_units="m"):
    """Resolve a ``sphere`` argument to a float array ``(x, y, z, radius)``.

    Accepts None, a scalar radius, a 4-element array-like, "auto", "eeglab",
    or a spherical ConductorModel. The result is always in meters.
    """
    from ..bem import ConductorModel, fit_sphere_to_headshape, get_fitting_dig

    if sphere is None:
        sphere = HEAD_SIZE_DEFAULT
        if info is not None:
            # Decide if we have enough dig points to do the auto fit
            try:
                get_fitting_dig(info, "extra", verbose="error")
            except (RuntimeError, ValueError):
                pass
            else:
                sphere = "auto"

    if isinstance(sphere, str):
        if sphere not in ("auto", "eeglab"):
            raise ValueError(
                f'sphere, if str, must be "auto" or "eeglab", got {sphere}'
            )
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if info is None:
            raise ValueError(f'sphere="{sphere}" requires info to be provided')

        if sphere == "auto":
            R, r0, _ = fit_sphere_to_headshape(
                info, verbose=_verbose_safe_false(), units="m"
            )
            sphere = tuple(r0) + (R,)
            sphere_units = "m"
        elif sphere == "eeglab":
            # We need coordinates for the 2D plane formed by
            # Fpz<->Oz and T7<->T8, as this plane will be the horizon (i.e. it
            # will determine the location of the head circle).
            #
            # We implement some special-handling in case Fpz is missing, as
            # this seems to be a quite common situation in numerous EEG labs.
            montage = info.get_montage()
            if montage is None:
                raise ValueError(
                    'No montage was set on your data, but sphere="eeglab" '
                    "can only work if digitization points for the EEG "
                    "channels are available. Consider calling set_montage() "
                    "to apply a montage."
                )
            ch_pos = montage.get_positions()["ch_pos"]
            horizon_ch_names = ("Fpz", "Oz", "T7", "T8")

            if "FPz" in ch_pos:  # "fix" naming
                ch_pos["Fpz"] = ch_pos["FPz"]
                del ch_pos["FPz"]
            elif "Fpz" not in ch_pos and "Oz" in ch_pos:
                logger.info(
                    "Approximating Fpz location by mirroring Oz along "
                    "the X and Y axes."
                )
                # This assumes Fpz and Oz have the same Z coordinate
                ch_pos["Fpz"] = ch_pos["Oz"] * [-1, -1, 1]

            for ch_name in horizon_ch_names:
                if ch_name not in ch_pos:
                    msg = (
                        f'sphere="eeglab" requires digitization points of '
                        f"the following electrode locations in the data: "
                        f'{", ".join(horizon_ch_names)}, but could not find: '
                        f"{ch_name}"
                    )
                    if ch_name == "Fpz":
                        msg += ", and was unable to approximate its location from Oz"
                    raise ValueError(msg)

            # Calculate the radius from: T7<->T8, Fpz<->Oz
            radius = np.abs(
                [
                    ch_pos["T7"][0],  # X axis
                    ch_pos["T8"][0],  # X axis
                    ch_pos["Fpz"][1],  # Y axis
                    ch_pos["Oz"][1],  # Y axis
                ]
            ).mean()

            # Calculate the center of the head sphere
            # Use 4 digpoints for each of the 3 axes to hopefully get a better
            # approximation than when using just 2 digpoints.
            sphere_locs = dict()
            for idx, axis in enumerate(("X", "Y", "Z")):
                sphere_locs[axis] = np.mean(
                    [
                        ch_pos["T7"][idx],
                        ch_pos["T8"][idx],
                        ch_pos["Fpz"][idx],
                        ch_pos["Oz"][idx],
                    ]
                )
            sphere = (sphere_locs["X"], sphere_locs["Y"], sphere_locs["Z"], radius)
            sphere_units = "m"
            del sphere_locs, radius, montage, ch_pos
    elif isinstance(sphere, ConductorModel):
        if not sphere["is_sphere"] or len(sphere["layers"]) == 0:
            raise ValueError(
                "sphere, if a ConductorModel, must be spherical "
                "with multiple layers, not a BEM or single-layer "
                f"sphere (got {sphere})"
            )
        sphere = tuple(sphere["r0"]) + (sphere["layers"][0]["rad"],)
        sphere_units = "m"
    sphere = np.array(sphere, dtype=float)
    if sphere.shape == ():
        # a scalar is a radius about the origin
        sphere = np.concatenate([[0.0] * 3, [sphere]])
    if sphere.shape != (4,):
        raise ValueError(
            "sphere must be float or 1D array of shape (4,), got "
            f"array-like of shape {sphere.shape}"
        )
    _check_option("sphere_units", sphere_units, ("m", "mm"))
    if sphere_units == "mm":
        sphere /= 1000.0

    sphere = np.array(sphere, float)
    return sphere


def _check_head_radius(radius, add_info=""):
    """Check that head radius is within a reasonable range (5. - 10.85 cm).

    Parameters
    ----------
    radius : float
        Head radius in meters.
    add_info : str
        Additional info to add to the warning message.

    Notes
    -----
    The maximum value was taken from the head size percentiles given in the
    following Wikipedia infographic:
    https://upload.wikimedia.org/wikipedia/commons/0/06/AvgHeadSizes.png

    the maximum radius is taken from the 99th percentile for men Glabella
    to back of the head measurements (Glabella is a point just above the
    Nasion):

    21.7cm / 2 = 10.85 cm = 0.1085 m

    The minimum value was taken from The National Center for Health
    Statistics (USA) infant head circumference percentiles:
    https://www.cdc.gov/growthcharts/html_charts/hcageinf.htm
    we take the minimum to be the radius corresponding to the 3rd
    percentile head circumference of female 0-month infant, rounded down:

    31.9302 cm circumference / (2 * pi) = 5.08 cm radius -> 0.05 m
    """
    min_radius = 0.05
    max_radius = 0.1085
    if radius > max_radius:
        msg = (
            f"Estimated head radius ({1e2 * radius:0.1f} cm) is "
            "above the 99th percentile for adult head size."
        )
        warn(msg + add_info)
    elif radius < min_radius:
        msg = (
            f"Estimated head radius ({1e2 * radius:0.1f} cm) is "
            "below the 3rd percentile for infant head size."
        )
        warn(msg + add_info)


def _check_freesurfer_home():
    """Return the FREESURFER_HOME config value, raising RuntimeError if unset."""
    from .config import get_config

    fs_home = get_config("FREESURFER_HOME")
    if fs_home is None:
        raise RuntimeError("The FREESURFER_HOME environment variable is not set.")
    return fs_home


def _suggest(val, options, cutoff=0.66):
    """Return a " Did you mean ...?" hint for ``val`` (empty if no close match)."""
    options = get_close_matches(val, options, cutoff=cutoff)
    if len(options) == 0:
        return ""
    elif len(options) == 1:
        return f" Did you mean {repr(options[0])}?"
    else:
        return f" Did you mean one of {repr(options)}?"
def _check_on_missing(on_missing, name="on_missing", *, extras=()):
    """Validate an ``on_missing``-style option.

    Parameters
    ----------
    on_missing : str
        Must be "raise", "warn", "ignore", or one of ``extras``.
    name : str
        Parameter name used in error messages.
    extras : iterable of str
        Additional allowed values.

    Raises
    ------
    TypeError
        If ``on_missing`` is not a str (via ``_validate_type``).
    ValueError
        If ``on_missing`` is not an allowed value (via ``_check_option``).
    """
    _validate_type(on_missing, str, name)
    _check_option(name, on_missing, ["raise", "warn", "ignore"] + list(extras))


def _on_missing(on_missing, msg, name="on_missing", error_klass=None):
    """Raise, warn about, or ignore ``msg`` according to ``on_missing``.

    Parameters
    ----------
    on_missing : str
        One of "raise", "warn", or "ignore".
    msg : str
        The message to raise or warn with.
    name : str
        Parameter name used when validating ``on_missing``.
    error_klass : type | None
        Exception class to raise; defaults to ValueError.
    """
    _check_on_missing(on_missing, name)
    error_klass = ValueError if error_klass is None else error_klass
    # _check_on_missing has already rejected anything other than "raise",
    # "warn", and "ignore", so no aliasing (e.g. "error"/"warning") can reach
    # this point; the former alias-mapping lines were dead code.
    if on_missing == "raise":
        raise error_klass(msg)
    elif on_missing == "warn":
        warn(msg)
    else:  # Ignore
        assert on_missing == "ignore"


def _safe_input(msg, *, alt=None, use=None):
    """Prompt with ``input(msg)``, tolerating environments without stdin.

    On EOFError, return ``use`` if it is not None; otherwise raise a
    RuntimeError that suggests ``alt`` as a workaround.
    """
    try:
        return input(msg)
    except EOFError:  # MATLAB or other non-stdin
        if use is not None:
            return use
        raise RuntimeError(
            f"Could not use input() to get a response to:\n{msg}\n"
            f"You can {alt} to avoid this error."
        )


def _ensure_events(events):
    """Convert ``events`` to an integer NumPy array of shape (N, 3).

    Raises
    ------
    TypeError
        If ``events`` is ragged or not of an integer dtype.
    ValueError
        If the array is not two-dimensional with three columns.
    """
    err_msg = f"events should be a NumPy array of integers, got {type(events)}"
    # NOTE(review): presumably captures warnings NumPy may emit when
    # converting ragged input — confirm against _record_warnings.
    with _record_warnings():
        try:
            events = np.asarray(events)
        except ValueError as np_err:
            if str(np_err).startswith(
                "setting an array element with a sequence. The requested "
                "array has an inhomogeneous shape"
            ):
                raise TypeError(err_msg) from None
            else:
                raise
    if not np.issubdtype(events.dtype, np.integer):
        raise TypeError(err_msg)
    if events.ndim != 2 or events.shape[1] != 3:
        raise ValueError(f"events must be of shape (N, 3), got {events.shape}")
    return events


def _to_rgb(*args, name="color", alpha=False):
    """Convert ``args`` to an RGB (or RGBA if ``alpha``) tuple via matplotlib.

    Raises
    ------
    ValueError
        If matplotlib cannot interpret the color; the message names ``name``.
    """
    from matplotlib.colors import colorConverter

    func = colorConverter.to_rgba if alpha else colorConverter.to_rgb
    try:
        return func(*args)
    except ValueError:
        args = args[0] if len(args) == 1 else args
        raise ValueError(
            f'Invalid RGB{"A" if alpha else ""} argument(s) for {name}: '
            f"{repr(args)}"
        ) from None


def _import_nibabel(why="use MRI files"):
    """Import and return nibabel, re-raising import errors with ``why`` context."""
    try:
        import nibabel as nib
    except ImportError as exp:
        raise exp.__class__(f"nibabel is required to {why}, got:\n{exp}") from None
    return nib


def _check_method_kwargs(func, kwargs, msg=None):
    """Ensure **kwargs are compatible with the function they're passed to.

    Raises
    ------
    TypeError
        If any key of ``kwargs`` is not a parameter name in ``func``'s
        signature.
    """
    from .misc import _pl

    valid = list(signature(func).parameters)
    is_invalid = np.isin(list(kwargs), valid, invert=True)
    if is_invalid.any():
        invalid_kw = np.array(list(kwargs))[is_invalid].tolist()
        s = _pl(invalid_kw)
        if msg is None:
            msg = f'function "{func}"'
        raise TypeError(
            f'Got unexpected keyword argument{s} {", ".join(invalid_kw)} for {msg}.'
        )