mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Added bindings and removed python implementation
This commit is contained in:
parent
b4f01a8f7d
commit
faf76a8133
@ -1,151 +0,0 @@
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def stft(
|
||||
x: mx.array,
|
||||
n_fft: int = 2048,
|
||||
hop_length: int = None,
|
||||
win_length: int = None,
|
||||
window: mx.array = None,
|
||||
center: bool = True,
|
||||
pad_mode: str = "reflect",
|
||||
normalized: bool = False,
|
||||
onesided: bool = True,
|
||||
return_complex: bool = True,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Computes the Short-Time Fourier Transform (STFT) of a signal.
|
||||
|
||||
Args:
|
||||
x (mx.array): Input signal array.
|
||||
n_fft (int): Number of FFT points. Default is 2048.
|
||||
hop_length (int, optional): Number of samples between successive frames.
|
||||
Defaults to `n_fft // 4`.
|
||||
win_length (int, optional): Window size. Defaults to `n_fft`.
|
||||
window (mx.array, optional): Window function. Defaults to a rectangular window.
|
||||
center (bool): If True, pads the input signal so that the frame is centered. Default is True.
|
||||
pad_mode (str): Padding mode to use if `center` is True. Default is "reflect".
|
||||
normalized (bool): If True, normalizes the STFT by the square root of `n_fft`. Default is False.
|
||||
onesided (bool): If True, returns only the positive frequencies. Default is True.
|
||||
return_complex (bool): If True, returns a complex-valued STFT. If False, returns real and imaginary parts separately. Default is True.
|
||||
|
||||
Returns:
|
||||
mx.array: The STFT of the input signal. The shape depends on the input and parameters.
|
||||
"""
|
||||
|
||||
if hop_length is None:
|
||||
hop_length = n_fft // 4
|
||||
|
||||
if win_length is None:
|
||||
win_length = n_fft
|
||||
|
||||
if window is None:
|
||||
window = mx.ones(win_length)
|
||||
|
||||
if win_length < n_fft:
|
||||
pad_left = (n_fft - win_length) // 2
|
||||
pad_right = n_fft - win_length - pad_left
|
||||
window = mx.pad(window, [(pad_left, pad_right)])
|
||||
|
||||
if center:
|
||||
pad_width = n_fft // 2
|
||||
x = mx.pad(x, [(pad_width, pad_width)], mode=pad_mode)
|
||||
|
||||
n_frames = 1 + (x.shape[0] - n_fft) // hop_length
|
||||
|
||||
frames = mx.stack(
|
||||
[x[i * hop_length : i * hop_length + n_fft] * window for i in range(n_frames)]
|
||||
)
|
||||
|
||||
stft = mx.fft.fft(frames, n=n_fft, axis=-1)
|
||||
|
||||
if normalized:
|
||||
stft = stft / mx.sqrt(n_fft)
|
||||
|
||||
if onesided:
|
||||
stft = stft[..., : n_fft // 2 + 1]
|
||||
|
||||
if not return_complex:
|
||||
stft = mx.stack([stft.real, stft.imag], axis=-1)
|
||||
|
||||
return stft
|
||||
|
||||
|
||||
def istft(
|
||||
stft_matrix: mx.array,
|
||||
hop_length: int = None,
|
||||
win_length: int = None,
|
||||
window: mx.array = None,
|
||||
center: bool = True,
|
||||
length: int = None,
|
||||
normalized: bool = False,
|
||||
onesided: bool = True,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Computes the inverse Short-Time Fourier Transform (iSTFT) of a signal.
|
||||
|
||||
Args:
|
||||
stft_matrix (mx.array): STFT matrix (output of `stft`).
|
||||
hop_length (int, optional): Number of samples between successive frames.
|
||||
Defaults to `stft_matrix.shape[-2] // 4`.
|
||||
win_length (int, optional): Window size. Defaults to `stft_matrix.shape[-2]`.
|
||||
window (mx.array, optional): Window function. Defaults to a rectangular window.
|
||||
center (bool): If True, removes padding added during STFT computation. Default is True.
|
||||
length (int, optional): Length of the output signal. If provided, the output is trimmed or zero-padded to this length.
|
||||
normalized (bool): If True, normalizes the iSTFT by the square root of the FFT size. Default is False.
|
||||
onesided (bool): If True, assumes the input STFT matrix is onesided. Default is True.
|
||||
|
||||
Returns:
|
||||
mx.array: The reconstructed time-domain signal.
|
||||
"""
|
||||
if hop_length is None:
|
||||
hop_length = stft_matrix.shape[-2] // 4
|
||||
|
||||
if win_length is None:
|
||||
win_length = stft_matrix.shape[-2]
|
||||
|
||||
if window is None:
|
||||
window = mx.ones(win_length)
|
||||
|
||||
if win_length < stft_matrix.shape[-2]:
|
||||
pad_left = (stft_matrix.shape[-2] - win_length) // 2
|
||||
pad_right = stft_matrix.shape[-2] - win_length - pad_left
|
||||
window = mx.pad(window, [(pad_left, pad_right)])
|
||||
|
||||
if stft_matrix.shape[-1] == 2:
|
||||
stft_matrix = stft_matrix[..., 0] + 1j * stft_matrix[..., 1]
|
||||
|
||||
if onesided:
|
||||
n_fft = 2 * (stft_matrix.shape[-1] - 1)
|
||||
full_stft = mx.zeros((*stft_matrix.shape[:-1], n_fft), dtype=stft_matrix.dtype)
|
||||
full_stft[..., : stft_matrix.shape[-1]] = stft_matrix
|
||||
full_stft[..., stft_matrix.shape[-1] :] = mx.conj(stft_matrix[..., -2:0:-1])
|
||||
stft_matrix = full_stft
|
||||
|
||||
frames = mx.fft.ifft(stft_matrix, n=stft_matrix.shape[-1], axis=-1)
|
||||
|
||||
if normalized:
|
||||
frames = frames * mx.sqrt(frames.shape[-1])
|
||||
|
||||
frames = frames * window
|
||||
|
||||
signal_length = (frames.shape[0] - 1) * hop_length + frames.shape[1]
|
||||
signal = mx.zeros(signal_length, dtype=frames.dtype)
|
||||
|
||||
for i in range(frames.shape[0]):
|
||||
signal[i * hop_length : i * hop_length + frames.shape[1]] += frames[i]
|
||||
|
||||
window_sum = mx.zeros(signal_length, dtype=frames.dtype)
|
||||
for i in range(frames.shape[0]):
|
||||
window_sum[i * hop_length : i * hop_length + frames.shape[1]] += window
|
||||
|
||||
signal = signal / window_sum
|
||||
|
||||
if center:
|
||||
pad_width = frames.shape[1] // 2
|
||||
signal = signal[pad_width:-pad_width]
|
||||
|
||||
if length is not None:
|
||||
signal = signal[:length]
|
||||
|
||||
return signal
|
@ -459,4 +459,100 @@ void init_fft(nb::module_& parent_module) {
|
||||
Returns:
|
||||
array: The real array containing the inverse of :func:`rfftn`.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"stft",
|
||||
[](const mx::array& x,
|
||||
int n_fft = 2048,
|
||||
int hop_length = -1,
|
||||
int win_length = -1,
|
||||
const mx::array& window = mx::array(),
|
||||
bool center = true,
|
||||
const std::string& pad_mode = "reflect",
|
||||
bool normalized = false,
|
||||
bool onesided = true,
|
||||
mx::StreamOrDevice s = {}) {
|
||||
return mx::stft::stft(
|
||||
x,
|
||||
n_fft,
|
||||
hop_length,
|
||||
win_length,
|
||||
window,
|
||||
center,
|
||||
pad_mode,
|
||||
normalized,
|
||||
onesided,
|
||||
s);
|
||||
},
|
||||
"x"_a,
|
||||
"n_fft"_a = 2048,
|
||||
"hop_length"_a = -1,
|
||||
"win_length"_a = -1,
|
||||
"window"_a = mx::array(),
|
||||
"center"_a = true,
|
||||
"pad_mode"_a = "reflect",
|
||||
"normalized"_a = false,
|
||||
"onesided"_a = true,
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Short-time Fourier Transform (STFT).
|
||||
|
||||
Args:
|
||||
x (array): Input signal.
|
||||
n_fft (int, optional): Number of FFT points. Default is 2048.
|
||||
hop_length (int, optional): Number of samples between successive frames. Default is `n_fft // 4`.
|
||||
win_length (int, optional): Window size. Default is `n_fft`.
|
||||
window (array, optional): Window function. Default is a rectangular window.
|
||||
center (bool, optional): Whether to pad the signal to center the frames. Default is True.
|
||||
pad_mode (str, optional): Padding mode. Default is "reflect".
|
||||
normalized (bool, optional): Whether to normalize the STFT. Default is False.
|
||||
onesided (bool, optional): Whether to return a one-sided STFT. Default is True.
|
||||
|
||||
Returns:
|
||||
array: The STFT of the input signal.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"istft",
|
||||
[](const mx::array& stft_matrix,
|
||||
int hop_length = -1,
|
||||
int win_length = -1,
|
||||
const mx::array& window = mx::array(),
|
||||
bool center = true,
|
||||
int length = -1,
|
||||
bool normalized = false,
|
||||
mx::StreamOrDevice s = {}) {
|
||||
return mx::stft::istft(
|
||||
stft_matrix,
|
||||
hop_length,
|
||||
win_length,
|
||||
window,
|
||||
center,
|
||||
length,
|
||||
normalized,
|
||||
s);
|
||||
},
|
||||
"stft_matrix"_a,
|
||||
"hop_length"_a = -1,
|
||||
"win_length"_a = -1,
|
||||
"window"_a = mx::array(),
|
||||
"center"_a = true,
|
||||
"length"_a = -1,
|
||||
"normalized"_a = false,
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Inverse Short-time Fourier Transform (ISTFT).
|
||||
|
||||
Args:
|
||||
stft_matrix (array): Input STFT matrix.
|
||||
hop_length (int, optional): Number of samples between successive frames. Default is `n_fft // 4`.
|
||||
win_length (int, optional): Window size. Default is `n_fft`.
|
||||
window (array, optional): Window function. Default is a rectangular window.
|
||||
center (bool, optional): Whether the signal was padded to center the frames. Default is True.
|
||||
length (int, optional): Length of the output signal. Default is inferred from the STFT matrix.
|
||||
normalized (bool, optional): Whether the STFT was normalized. Default is False.
|
||||
|
||||
Returns:
|
||||
array: The reconstructed signal.
|
||||
)pbdoc");
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user