273 lines
9.5 KiB
Python
273 lines
9.5 KiB
Python
# Authors: The MNE-Python contributors.
|
|
# License: BSD-3-Clause
|
|
# Copyright the MNE-Python contributors.
|
|
|
|
import warnings
|
|
from functools import partial
|
|
|
|
import numpy as np
|
|
from scipy.signal import spectrogram
|
|
|
|
from ..parallel import parallel_func
|
|
from ..utils import _check_option, _ensure_int, logger, verbose
|
|
from ..utils.numerics import _mask_to_onsets_offsets
|
|
|
|
|
|
# adapted from SciPy
# https://github.com/scipy/scipy/blob/f71e7fad717801c4476312fe1e23f2dfbb4c9d7f/scipy/signal/_spectral_py.py#L2019 # noqa: E501
def _median_biases(n):
    """Return median bias corrections for 0 through ``n`` values at once.

    This is a vectorized form of SciPy's private helper::

        def _median_bias(n):
            ii_2 = 2 * np.arange(1., (n - 1) // 2 + 1)
            return 1 + np.sum(1. / (ii_2 + 1) - 1. / ii_2)

    Entry ``k`` of the result equals ``_median_bias(k)``. That expression
    sums ``(k - 1) // 2`` terms, so k = 0, 1, 2 have no terms (bias 1),
    k = 3 and 4 share one term, k = 5 and 6 share two terms, and so on.
    Every entry can therefore be read out of a single cumulative sum.

    Parameters
    ----------
    n : int
        Largest number of values whose median will be bias-corrected.

    Returns
    -------
    biases : ndarray, shape (n + 1,)
        ``biases[k]`` is the divisor to apply to a median of ``k`` values.
    """
    biases = np.ones(n + 1)
    if n < 3:
        # fewer than three values: no correction terms at all
        return biases
    n_terms = (n - 1) // 2
    evens = 2 * np.arange(1, n_terms + 1)
    running = 1 + np.cumsum(1.0 / (evens + 1) - 1.0 / evens)
    # number of terms for each count k = 3..n, minus one to index `running`
    terms_per_count = (np.arange(3, n + 1) - 1) // 2
    biases[3:] = running[terms_per_count - 1]
    return biases
|
|
|
|
|
|
def _decomp_aggregate_mask(epoch, func, average, freq_sl):
    """Run ``func`` on ``epoch``, crop frequencies and aggregate segments.

    Parameters
    ----------
    epoch : ndarray
        Data passed unchanged to ``func``.
    func : callable
        Returns a 3-tuple whose last element is the spectrogram, with
        frequencies on the second-to-last axis and segments on the last
        axis (e.g. a configured ``scipy.signal.spectrogram``).
    average : str | None
        ``"mean"`` or ``"median"`` aggregate over segments (ignoring NaN);
        any other value returns the cropped segments unaggregated.
    freq_sl : slice
        Frequency bins to keep.

    Returns
    -------
    spect : ndarray
        Cropped spectrogram, with the segment axis removed when aggregated.
    """
    _freqs, _times, full = func(epoch)
    cropped = full[..., freq_sl, :]
    # Aggregate here (per chunk) rather than after collecting every chunk,
    # to save memory.
    if average == "mean":
        return np.nanmean(cropped, axis=-1)
    if average == "median":
        # nanmedian skips NaN segments, so the bias correction is looked up
        # by the number of non-NaN segments in each row.
        n_valid = (~np.isnan(cropped)).sum(-1)
        correction = _median_biases(cropped.shape[-1])
        return np.nanmedian(cropped, axis=-1) / correction[n_valid]
    return cropped
|
|
|
|
|
|
def _spect_func(epoch, func, freq_sl, average, *, output="power"):
    """Compute, crop and optionally aggregate the spectrogram of one chunk.

    Parameters
    ----------
    epoch : ndarray
        Data chunk; time is the last axis.
    func : callable
        Spectrogram function, forwarded to ``_decomp_aggregate_mask``.
    freq_sl : slice
        Frequency bins to keep, forwarded to ``_decomp_aggregate_mask``.
    average : str | None
        Segment aggregation, forwarded to ``_decomp_aggregate_mask``.
    output : str
        Accepted so callers can pass it uniformly; not used here.

    Returns
    -------
    spect : ndarray
        Output of ``_decomp_aggregate_mask``, computed on the whole chunk or,
        for chunks above ~10 MB, one 1D signal at a time.
    """
    agg_kwargs = {"func": func, "average": average, "freq_sl": freq_sl}
    # Per-signal calls keep peak memory down for big chunks but add Python
    # call overhead, so only split above the size threshold. A custom
    # spectrogram that averages after each transform could avoid the memory
    # cost, but would need many more Python-level calls.
    if epoch.nbytes <= 10e6:
        return _decomp_aggregate_mask(epoch, **agg_kwargs)
    return np.apply_along_axis(_decomp_aggregate_mask, -1, epoch, **agg_kwargs)
|
|
|
|
|
|
def _check_nfft(n, n_fft, n_per_seg, n_overlap):
    """Ensure n_fft, n_per_seg and n_overlap make sense.

    Parameters
    ----------
    n : int
        Number of samples in the signal.
    n_fft : int
        Requested FFT length.
    n_per_seg : int | None
        Requested segment length; ``None`` means "use ``n_fft``".
    n_overlap : int
        Requested overlap between consecutive segments, in samples.

    Returns
    -------
    n_fft : int
        Unchanged.
    n_per_seg : int
        Segment length, clamped to be at most ``n_fft`` and at most ``n``.
    n_overlap : int
        Unchanged.

    Raises
    ------
    ValueError
        If ``n_per_seg`` is None while ``n_fft > n`` (zero-padding must be
        requested explicitly through ``n_per_seg``), or if ``n_overlap`` is
        not smaller than the clamped ``n_per_seg``.
    """
    if n_per_seg is None and n_fft > n:
        raise ValueError(
            "If n_per_seg is None n_fft is not allowed to be > "
            "n_times. If you want zero-padding, you have to set "
            f"n_per_seg to relevant length. Got n_fft of {n_fft} while"
            f" signal length is {n}."
        )
    requested = n_fft if n_per_seg is None else n_per_seg
    n_per_seg = min(requested, n_fft, n)
    if n_overlap >= n_per_seg:
        # The check is ">=", so the message must say so: equal values fail too.
        raise ValueError(
            "n_overlap cannot be greater than or equal to n_per_seg (or n_fft). "
            f"Got n_overlap of {n_overlap} while n_per_seg is {n_per_seg}."
        )
    return n_fft, n_per_seg, n_overlap
|
|
|
|
|
|
@verbose
def psd_array_welch(
    x,
    sfreq,
    fmin=0,
    fmax=np.inf,
    n_fft=256,
    n_overlap=0,
    n_per_seg=None,
    n_jobs=None,
    average="mean",
    window="hamming",
    remove_dc=True,
    *,
    output="power",
    verbose=None,
):
    """Compute power spectral density (PSD) using Welch's method.

    Welch's method is described in :footcite:t:`Welch1967`.

    Parameters
    ----------
    x : array, shape=(..., n_times)
        The data to compute PSD from. If ``x`` contains NaN values, they must
        occupy the same time points in every signal; the NaN-free spans are
        then analyzed separately and their results are combined.
    sfreq : float
        The sampling frequency.
    fmin : float
        The lower frequency of interest.
    fmax : float
        The upper frequency of interest.
    n_fft : int
        The length of FFT used (default: 256). If ``n_per_seg`` is larger
        than ``n_fft``, it is reduced to ``n_fft``. The segments will be
        zero-padded if ``n_fft > n_per_seg``.
    n_overlap : int
        The number of points of overlap between segments. Must be smaller
        than the effective ``n_per_seg``, otherwise a ``ValueError`` is
        raised. The default value is 0.
    n_per_seg : int | None
        Length of each Welch segment (tapered with ``window``). Defaults
        to None, which sets n_per_seg equal to n_fft. It is reduced to
        ``n_times`` if the signal is shorter.
    %(n_jobs)s
    %(average_psd)s

        .. versionadded:: 0.19.0
    %(window_psd)s

        .. versionadded:: 0.22.0
    %(remove_dc)s

    output : str
        The format of the returned ``psds`` array, ``'complex'`` or
        ``'power'``:

        * ``'power'`` : the power spectral density is returned.
        * ``'complex'`` : the complex fourier coefficients are returned per
          window.

        .. versionadded:: 1.4.0
    %(verbose)s

    Returns
    -------
    psds : ndarray, shape (..., n_freqs) or (..., n_freqs, n_segments)
        The power spectral densities. If ``average='mean'`` or
        ``average='median'``, the returned array has the shape of the input
        data with the time dimension replaced by a frequency dimension.
        If ``average=None``, the returned array will have the same shape as
        the input data plus two additional dimensions corresponding to
        frequencies and the unaggregated segments, respectively; when ``x``
        contains NaN spans, the segments of all NaN-free spans are
        concatenated along that last axis.
    freqs : ndarray, shape (n_freqs,)
        The frequencies.

    Notes
    -----
    .. versionadded:: 0.14.0

    References
    ----------
    .. footbibliography::
    """
    _check_option("average", average, (None, False, "mean", "median"))
    _check_option("output", output, ("power", "complex"))
    detrend = "constant" if remove_dc else False
    mode = "complex" if output == "complex" else "psd"
    n_fft = _ensure_int(n_fft, "n_fft")
    n_overlap = _ensure_int(n_overlap, "n_overlap")
    if n_per_seg is not None:
        n_per_seg = _ensure_int(n_per_seg, "n_per_seg")
    if average is False:
        average = None

    # Work on a 2D (n_signals, n_times) view; the leading shape is restored
    # at the end.
    dshape = x.shape[:-1]
    n_times = x.shape[-1]
    x = x.reshape(-1, n_times)

    # Prep the PSD
    n_fft, n_per_seg, n_overlap = _check_nfft(n_times, n_fft, n_per_seg, n_overlap)
    win_size = n_fft / float(sfreq)
    logger.info(f"Effective window size : {win_size:0.3f} (s)")
    freqs = np.arange(n_fft // 2 + 1, dtype=float) * (sfreq / n_fft)
    freq_mask = (freqs >= fmin) & (freqs <= fmax)
    if not freq_mask.any():
        raise ValueError(f"No frequencies found between fmin={fmin} and fmax={fmax}")
    freq_sl = slice(*(np.where(freq_mask)[0][[0, -1]] + [0, 1]))
    del freq_mask
    freqs = freqs[freq_sl]

    # Parallelize across first N-1 dimensions
    logger.debug(
        f"Spectogram using {n_fft}-point FFT on {n_per_seg} samples with "
        f"{n_overlap} overlap and {window} window"
    )

    parallel, my_spect_func, n_jobs = parallel_func(_spect_func, n_jobs=n_jobs)
    _func = partial(
        spectrogram,
        detrend=detrend,
        noverlap=n_overlap,
        nperseg=n_per_seg,
        nfft=n_fft,
        fs=sfreq,
        window=window,
        mode=mode,
    )
    good_mask = ~np.isnan(x)
    if not good_mask.all():
        # NaNs originate from annot, so must match for all channels. Note that we CANNOT
        # use np.testing.assert_allclose() here; it is strict about shapes/broadcasting
        assert np.allclose(good_mask, good_mask[[0]], equal_nan=True)
        t_onsets, t_offsets = _mask_to_onsets_offsets(good_mask[0])
        x_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)]
        span_lengths = [span.shape[-1] for span in x_splits]
        if average is None:
            # Unaggregated output: keep every segment of every span by
            # concatenating along the segment (last) axis. Averaging across
            # spans would mix unrelated segments, or fail outright when spans
            # yield different numbers of segments.
            agg_func = partial(np.concatenate, axis=-1)
        else:
            # weights reflect the number of samples used from each span. For spans
            # longer than `n_per_seg`, trailing samples may be discarded. For spans
            # shorter than `n_per_seg`, the wrapped function
            # (`scipy.signal.spectrogram`) automatically reduces `n_per_seg` to
            # match the span length (with a warning).
            step = n_per_seg - n_overlap
            weights = [
                w if w < n_per_seg else w - ((w - n_overlap) % step)
                for w in span_lengths
            ]
            agg_func = partial(np.average, axis=0, weights=weights)
        if n_jobs > 1:
            logger.info(
                f"Data split into {len(x_splits)} (probably unequal) chunks due to "
                '"bad_*" annotations. Parallelization may be sub-optimal.'
            )
        if (np.array(span_lengths) < n_per_seg).any():
            logger.info(
                "At least one good data span is shorter than n_per_seg, and will be "
                "analyzed with a shorter window than the rest of the file."
            )

        def func(span, **kwargs):
            # SciPy shrinks nperseg to the span length for short spans but
            # still requires noverlap < nperseg; clamp the overlap so such a
            # span is analyzed as one segment instead of raising. For spans
            # of at least n_per_seg samples this leaves n_overlap unchanged.
            kwargs.setdefault("noverlap", min(n_overlap, max(span.shape[-1] - 1, 0)))
            # swallow SciPy warnings caused by short good data spans
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    action="ignore",
                    module="scipy",
                    category=UserWarning,
                    message=r"nperseg = \d+ is greater than input length",
                )
                return _func(span, **kwargs)

    else:
        x_splits = [arr for arr in np.array_split(x, n_jobs) if arr.size != 0]
        agg_func = partial(np.concatenate, axis=0)
        func = _func
    del good_mask

    f_spect = parallel(
        my_spect_func(d, func=func, freq_sl=freq_sl, average=average, output=output)
        for d in x_splits
    )
    psds = agg_func(f_spect)
    shape = dshape + (len(freqs),)
    if average is None:
        shape = shape + (-1,)
    psds = psds.reshape(shape)
    return psds, freqs
|