# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
from math import ceil
|
|
|
|
import numpy as np
|
|
from scipy.fft import irfft, rfft, rfftfreq
|
|
|
|
from ..utils import logger, verbose
|
|
|
|
|
|
@verbose
def stft(x, wsize, tstep=None, verbose=None):
    """STFT Short-Term Fourier Transform using a sine window.

    The transformation is designed to be a tight frame that can be
    perfectly inverted. It only returns the positive frequencies.

    Parameters
    ----------
    x : array, shape (n_signals, n_times)
        Containing multi-channels signal. Must be real valued; a 1D array
        (or any 1D/2D array-like) is accepted and a 1D input is treated as
        a single signal.
    wsize : int
        Length of the STFT window in samples (must be a positive multiple
        of 4).
    tstep : int
        Step between successive windows in samples (must be a positive
        multiple of 2, a divider of wsize and not larger than wsize/2)
        (default: wsize/2).
    %(verbose)s

    Returns
    -------
    X : array, shape (n_signals, wsize // 2 + 1, n_step)
        STFT coefficients for positive frequencies with
        ``n_step = ceil(T / tstep)``.

    Raises
    ------
    ValueError
        If x is complex or not 1D/2D, or if wsize or tstep violate the
        constraints listed above.

    See Also
    --------
    istft
    stftfreq
    """
    # Coerce array-likes (e.g. nested lists) so ``.ndim``/``.shape`` exist.
    x = np.asarray(x)
    if not np.isrealobj(x):
        raise ValueError("x is not a real valued array")

    if x.ndim == 1:
        x = x[None, :]
    if x.ndim != 2:
        raise ValueError(f"x must be a 1D or 2D array, got {x.ndim} dimensions.")

    n_signals, T = x.shape
    wsize = int(wsize)

    # Errors and warnings
    if wsize <= 0:
        raise ValueError(f"The window length must be positive, got {wsize}.")
    if wsize % 4:
        raise ValueError("The window length must be a multiple of 4.")

    tstep = wsize // 2 if tstep is None else int(tstep)

    # Checked before the modulo below: a zero step would raise
    # ZeroDivisionError and a negative one a negative number of time steps.
    if tstep <= 0:
        raise ValueError(f"The step size must be positive, got {tstep}.")
    if (wsize % tstep) or (tstep % 2):
        raise ValueError(
            "The step size must be a multiple of 2 and a "
            "divider of the window length."
        )
    if tstep > wsize / 2:
        raise ValueError("The step size must be smaller than half the window length.")

    n_step = ceil(T / tstep)
    n_freq = wsize // 2 + 1
    logger.info(f"Number of frequencies: {n_freq}")
    logger.info(f"Number of time steps: {n_step}")

    X = np.zeros((n_signals, n_freq, n_step), dtype=np.complex128)

    if n_signals == 0:
        return X

    # Defining sine window
    win = np.sin(np.arange(0.5, wsize + 0.5) / wsize * np.pi)
    win2 = win**2

    # Overlap-added squared window over all frames; each frame is divided by
    # sqrt(wsize * swin) at its position.
    swin = np.zeros((n_step - 1) * tstep + wsize)
    for t in range(n_step):
        swin[t * tstep : t * tstep + wsize] += win2
    swin = np.sqrt(wsize * swin)

    # Zero-padding and pre-processing for edges: the signal starts
    # (wsize - tstep) // 2 samples into the padded buffer.
    xp = np.zeros((n_signals, wsize + (n_step - 1) * tstep), dtype=x.dtype)
    start = (wsize - tstep) // 2
    xp[:, start : start + T] = x

    for t in range(n_step):
        # Framing
        seg = slice(t * tstep, t * tstep + wsize)
        frame = xp[:, seg] * (win / swin[seg])[None, :]
        # FFT
        X[:, :, t] = rfft(frame)

    return X
|
|
|
|
|
|
def istft(X, tstep=None, Tx=None):
    """ISTFT Inverse Short-Term Fourier Transform using a sine window.

    Parameters
    ----------
    X : array, shape (..., wsize / 2 + 1, n_step)
        The STFT coefficients for positive frequencies.
    tstep : int
        Step between successive windows in samples (must be a positive
        multiple of 2, a divider of wsize and not larger than wsize/2)
        (default: wsize/2).
    Tx : int
        Length of returned signal. If None Tx = n_step * tstep.

    Returns
    -------
    x : array, shape (..., Tx)
        Array containing the inverse STFT signal.

    Raises
    ------
    ValueError
        If X has fewer than 2 dimensions, an even number of frequency rows,
        or if tstep violates the constraints listed above.

    See Also
    --------
    stft
    """
    # Errors and warnings
    X = np.asarray(X)
    if X.ndim < 2:
        raise ValueError(f"X must have ndim >= 2, got {X.ndim}")
    n_win, n_step = X.shape[-2:]
    signal_shape = X.shape[:-2]
    if n_win % 2 == 0:
        raise ValueError("The number of rows of the STFT matrix must be odd.")

    wsize = 2 * (n_win - 1)
    # The step must be an integer: it is used in array shapes and slices
    # below (a float ``wsize / 2`` default used to raise TypeError there).
    tstep = wsize // 2 if tstep is None else int(tstep)

    # Checked before the modulo below, which would raise ZeroDivisionError
    # for a zero step (e.g. a single-row X gives wsize == 0).
    if tstep <= 0:
        raise ValueError(f"The step size must be positive, got {tstep}.")

    if wsize % tstep:
        raise ValueError(
            "The step size must be a divider of two times the "
            "number of rows of the STFT matrix minus two."
        )

    # wsize is always even here; it is the step that must be even so the
    # (wsize - tstep) // 2 edge offset is exact.
    if tstep % 2:
        raise ValueError("The step size must be a multiple of 2.")

    if tstep > wsize / 2:
        raise ValueError(
            "The step size must be smaller than the number of "
            "rows of the STFT matrix minus one."
        )

    T = n_step * tstep
    Tx = T if Tx is None else int(Tx)

    n_out = T + wsize - tstep
    x = np.zeros(signal_shape + (n_out,), dtype=np.float64)

    if np.prod(signal_shape) == 0:
        return x[..., :Tx]

    # Defining sine window
    win = np.sin(np.arange(0.5, wsize + 0.5) / wsize * np.pi)
    win2 = win**2

    # Pre-processing for edges: overlap-added squared window, scaled so that
    # combined with the forward normalization the overlap-add sums to one.
    swin = np.zeros(n_out, dtype=np.float64)
    for t in range(n_step):
        swin[t * tstep : t * tstep + wsize] += win2
    swin = np.sqrt(swin / wsize)

    for t in range(n_step):
        seg = slice(t * tstep, t * tstep + wsize)
        # IFFT
        frame = irfft(X[..., t], wsize)
        # Overlap-add
        frame *= win / swin[seg]
        x[..., seg] += frame

    # Truncation: drop the (wsize - tstep) // 2 samples of edge padding.
    start = (wsize - tstep) // 2
    x = x[..., start : start + T + 1]
    return x[..., :Tx].copy()
|
|
|
|
|
|
def stftfreq(wsize, sfreq=None):  # noqa: D401
    """Compute frequencies of stft transformation.

    Parameters
    ----------
    wsize : int
        Size of stft window. Non-integer values are truncated with ``int``,
        the same way stft handles its ``wsize`` argument.
    sfreq : float
        Sampling frequency. If None the frequencies are normalized, i.e.
        given in cycles per sample between 0 and 0.5; otherwise they are
        given in Hz.

    Returns
    -------
    freqs : array, shape (wsize // 2 + 1,)
        The positive frequencies returned by stft.

    See Also
    --------
    stft
    istft
    """
    # rfftfreq only accepts an integer length, whereas stft accepts e.g.
    # wsize=256.0; cast here so both functions agree on valid inputs.
    freqs = rfftfreq(int(wsize))
    if sfreq is not None:
        freqs *= float(sfreq)
    return freqs
|
|
|
|
|
|
def stft_norm2(X):
    """Compute L2 norm of STFT transform.

    stft stores only the non-negative frequencies, so every coefficient is
    counted twice (for itself and its negative-frequency mirror) except
    those in the first and last frequency rows, which are counted once.
    As the stft uses a tight frame, this quantity is conserved by it.

    Parameters
    ----------
    X : 3D complex array
        The STFT transforms

    Returns
    -------
    norms2 : array, shape (n_signals,)
        The squared L2 norm of every row of X.
    """
    energy = (X * X.conj()).real
    doubled = 2.0 * energy.sum(axis=2).sum(axis=1)
    # The edge frequency rows have no mirror: take them back out once.
    first_row = np.sum(energy[:, 0, :], axis=1)
    last_row = np.sum(energy[:, -1, :], axis=1)
    return doubled - first_row - last_row
|
|
|
|
|
|
def stft_norm1(X):
    """Compute L1 norm of STFT transform.

    stft stores only the non-negative frequencies, so every coefficient
    magnitude is counted twice (for itself and its negative-frequency
    mirror) except those in the first and last frequency rows, which are
    counted once.

    Parameters
    ----------
    X : 3D complex array
        The STFT transforms

    Returns
    -------
    norms : array, shape (n_signals,)
        The L1 norm of every row of X.
    """
    magnitude = np.abs(X)
    doubled = 2.0 * magnitude.sum(axis=(1, 2))
    # The edge frequency rows have no mirror: take them back out once.
    first_row = np.sum(magnitude[:, 0, :], axis=1)
    last_row = np.sum(magnitude[:, -1, :], axis=1)
    return doubled - first_row - last_row
|