From 9e8befbe8d5431d485310501690c7743461d9cb7 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Sun, 13 Apr 2025 08:55:56 +0530 Subject: [PATCH 1/9] Added Short time fourier transform and ISTFT implementations --- python/mlx/stft.py | 115 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 python/mlx/stft.py diff --git a/python/mlx/stft.py b/python/mlx/stft.py new file mode 100644 index 000000000..fdd6148ce --- /dev/null +++ b/python/mlx/stft.py @@ -0,0 +1,115 @@ +import mlx.core as mx +import numpy as np + +def stft( + x: mx.array, + n_fft: int = 2048, + hop_length: int = None, + win_length: int = None, + window: mx.array = None, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + onesided: bool = True, + return_complex: bool = True, +) -> mx.array: + + if hop_length is None: + hop_length = n_fft // 4 + + if win_length is None: + win_length = n_fft + + if window is None: + window = mx.ones(win_length) + + if win_length < n_fft: + pad_left = (n_fft - win_length) // 2 + pad_right = n_fft - win_length - pad_left + window = mx.pad(window, [(pad_left, pad_right)]) + + if center: + pad_width = n_fft // 2 + x = mx.pad(x, [(pad_width, pad_width)], mode=pad_mode) + + n_frames = 1 + (x.shape[0] - n_fft) // hop_length + + frames = mx.stack([ + x[i * hop_length : i * hop_length + n_fft] * window + for i in range(n_frames) + ]) + + stft = mx.fft.fft(frames, n=n_fft, axis=-1) + + if normalized: + stft = stft / mx.sqrt(n_fft) + + if onesided: + stft = stft[..., :n_fft//2 + 1] + + if not return_complex: + stft = mx.stack([stft.real, stft.imag], axis=-1) + + return stft + +def istft( + stft_matrix: mx.array, + hop_length: int = None, + win_length: int = None, + window: mx.array = None, + center: bool = True, + length: int = None, + normalized: bool = False, + onesided: bool = True, +) -> mx.array: + if hop_length is None: + hop_length = stft_matrix.shape[-2] // 4 + + if win_length is None: + win_length = stft_matrix.shape[-2] + + if window is None: + window = mx.ones(win_length) + + if win_length < stft_matrix.shape[-2]: + pad_left = (stft_matrix.shape[-2] - win_length) // 2 + pad_right = stft_matrix.shape[-2] - win_length - pad_left + window = mx.pad(window, [(pad_left, pad_right)]) + + if stft_matrix.shape[-1] == 2: + stft_matrix = stft_matrix[..., 0] + 1j * stft_matrix[..., 1] + + if onesided: + n_fft = 2 * (stft_matrix.shape[-1] - 1) + full_stft = mx.zeros((*stft_matrix.shape[:-1], n_fft), dtype=stft_matrix.dtype) + full_stft[..., :stft_matrix.shape[-1]] = stft_matrix + full_stft[..., stft_matrix.shape[-1]:] = mx.conj(stft_matrix[..., -2:0:-1]) + stft_matrix = full_stft + + frames = mx.fft.ifft(stft_matrix, n=stft_matrix.shape[-1], axis=-1) + + if normalized: + frames = frames * mx.sqrt(frames.shape[-1]) + + frames = frames * window + + signal_length = (frames.shape[0] - 1) * hop_length + frames.shape[1] + signal = mx.zeros(signal_length, dtype=frames.dtype) + + for i in range(frames.shape[0]): + signal[i * hop_length : i * hop_length + frames.shape[1]] += frames[i] + + window_sum = mx.zeros(signal_length, dtype=frames.dtype) + for i in range(frames.shape[0]): + window_sum[i * hop_length : i * hop_length + frames.shape[1]] += window + + signal = signal / window_sum + + if center: + pad_width = frames.shape[1] // 2 + signal = signal[pad_width:-pad_width] + + if length is not None: + signal = signal[:length] + + return signal \ No newline at end of file From f50cce83a510f49722c4090ebafac61a037e1da5 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Thu, 17 Apr 2025 00:24:49 +0530 Subject: [PATCH 2/9] Formatted --- python/mlx/stft.py | 59 +++++++++++++++++++++++----------------------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/python/mlx/stft.py b/python/mlx/stft.py index fdd6148ce..ac2cbf349 100644 --- a/python/mlx/stft.py +++ b/python/mlx/stft.py @@ -1,6 +1,7 @@ import mlx.core as mx import numpy as np + def stft( x: mx.array, n_fft: int = 2048, @@ -16,10 +17,10 @@ def stft( if hop_length is None: hop_length = n_fft // 4 - + if win_length is None: win_length = n_fft - + if window is None: window = mx.ones(win_length) @@ -34,24 +35,24 @@ def stft( n_frames = 1 + (x.shape[0] - n_fft) // hop_length - frames = mx.stack([ - x[i * hop_length : i * hop_length + n_fft] * window - for i in range(n_frames) - ]) + frames = mx.stack( + [x[i * hop_length : i * hop_length + n_fft] * window for i in range(n_frames)] + ) stft = mx.fft.fft(frames, n=n_fft, axis=-1) - + if normalized: stft = stft / mx.sqrt(n_fft) - + if onesided: - stft = stft[..., :n_fft//2 + 1] - + stft = stft[..., : n_fft // 2 + 1] + if not return_complex: stft = mx.stack([stft.real, stft.imag], axis=-1) - + return stft + def istft( stft_matrix: mx.array, hop_length: int = None, @@ -64,52 +65,52 @@ def istft( ) -> mx.array: if hop_length is None: hop_length = stft_matrix.shape[-2] // 4 - + if win_length is None: win_length = stft_matrix.shape[-2] - + if window is None: window = mx.ones(win_length) - + if win_length < stft_matrix.shape[-2]: pad_left = (stft_matrix.shape[-2] - win_length) // 2 pad_right = stft_matrix.shape[-2] - win_length - pad_left window = mx.pad(window, [(pad_left, pad_right)]) - + if stft_matrix.shape[-1] == 2: stft_matrix = stft_matrix[..., 0] + 1j * stft_matrix[..., 1] - + if onesided: n_fft = 2 * (stft_matrix.shape[-1] - 1) full_stft = mx.zeros((*stft_matrix.shape[:-1], n_fft), dtype=stft_matrix.dtype) - full_stft[..., :stft_matrix.shape[-1]] = stft_matrix - full_stft[..., stft_matrix.shape[-1]:] = mx.conj(stft_matrix[..., -2:0:-1]) + full_stft[..., : stft_matrix.shape[-1]] = stft_matrix + full_stft[..., stft_matrix.shape[-1] :] = mx.conj(stft_matrix[..., -2:0:-1]) stft_matrix = full_stft - + frames = mx.fft.ifft(stft_matrix, n=stft_matrix.shape[-1], axis=-1) - + if normalized: frames = frames * mx.sqrt(frames.shape[-1]) - + frames = frames * window - + signal_length = (frames.shape[0] - 1) * hop_length + frames.shape[1] signal = mx.zeros(signal_length, dtype=frames.dtype) - + for i in range(frames.shape[0]): signal[i * hop_length : i * hop_length + frames.shape[1]] += frames[i] - + window_sum = mx.zeros(signal_length, dtype=frames.dtype) for i in range(frames.shape[0]): window_sum[i * hop_length : i * hop_length + frames.shape[1]] += window - + signal = signal / window_sum - + if center: pad_width = frames.shape[1] // 2 signal = signal[pad_width:-pad_width] - + if length is not None: signal = signal[:length] - - return signal \ No newline at end of file + + return signal From 91097e61795dac61bb6d83447c5bcea667bfb6fc Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Thu, 17 Apr 2025 08:20:32 +0530 Subject: [PATCH 3/9] Added docstrings --- python/mlx/stft.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/python/mlx/stft.py b/python/mlx/stft.py index ac2cbf349..a4abfdc87 100644 --- a/python/mlx/stft.py +++ b/python/mlx/stft.py @@ -1,5 +1,4 @@ import mlx.core as mx -import numpy as np def stft( @@ -14,6 +13,25 @@ def stft( onesided: bool = True, return_complex: bool = True, ) -> mx.array: + """ + Computes the Short-Time Fourier Transform (STFT) of a signal. + + Args: + x (mx.array): Input signal array. + n_fft (int): Number of FFT points. Default is 2048. + hop_length (int, optional): Number of samples between successive frames. + Defaults to `n_fft // 4`. + win_length (int, optional): Window size. Defaults to `n_fft`. + window (mx.array, optional): Window function. Defaults to a rectangular window. + center (bool): If True, pads the input signal so that the frame is centered. Default is True. + pad_mode (str): Padding mode to use if `center` is True. Default is "reflect". + normalized (bool): If True, normalizes the STFT by the square root of `n_fft`. Default is False. + onesided (bool): If True, returns only the positive frequencies. Default is True. + return_complex (bool): If True, returns a complex-valued STFT. If False, returns real and imaginary parts separately. Default is True. + + Returns: + mx.array: The STFT of the input signal. The shape depends on the input and parameters. + """ if hop_length is None: hop_length = n_fft // 4 @@ -63,6 +81,23 @@ def istft( normalized: bool = False, onesided: bool = True, ) -> mx.array: + """ + Computes the inverse Short-Time Fourier Transform (iSTFT) of a signal. + + Args: + stft_matrix (mx.array): STFT matrix (output of `stft`). + hop_length (int, optional): Number of samples between successive frames. + Defaults to `stft_matrix.shape[-2] // 4`. + win_length (int, optional): Window size. Defaults to `stft_matrix.shape[-2]`. + window (mx.array, optional): Window function. Defaults to a rectangular window. + center (bool): If True, removes padding added during STFT computation. Default is True. + length (int, optional): Length of the output signal. If provided, the output is trimmed or zero-padded to this length. + normalized (bool): If True, normalizes the iSTFT by the square root of the FFT size. Default is False. + onesided (bool): If True, assumes the input STFT matrix is onesided. Default is True. + + Returns: + mx.array: The reconstructed time-domain signal. + """ if hop_length is None: hop_length = stft_matrix.shape[-2] // 4 From b4f01a8f7dbeabb7406298a345e6dd4869669818 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Thu, 17 Apr 2025 22:27:46 +0530 Subject: [PATCH 4/9] Added cpp implementation of stft --- mlx/fft.cpp | 132 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/mlx/fft.cpp b/mlx/fft.cpp index f0d41bf0f..dfb585dfc 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -1,7 +1,9 @@ // Copyright © 2023 Apple Inc. +#include #include #include +#include #include "mlx/fft.h" #include "mlx/ops.h" @@ -190,3 +192,133 @@ array irfftn(const array& a, StreamOrDevice s /* = {} */) { } } // namespace mlx::core::fft + +namespace mlx::core::stft { + +array stft( + const array& x, + int n_fft = 2048, + int hop_length = -1, + int win_length = -1, + const array& window, + bool center = true, + const std::string& pad_mode = "reflect", + bool normalized = false, + bool onesided = true, + StreamOrDevice s /* = {} */) { + if (hop_length == -1) + hop_length = n_fft / 4; + if (win_length == -1) + win_length = n_fft; + + array win = (window.size() == 0) ? ones({win_length}, float32, s) : window; + + if (win_length < n_fft) { + int pad_left = (n_fft - win_length) / 2; + int pad_right = n_fft - win_length - pad_left; + + array left_pad = zeros({pad_left}, float32, s); + array right_pad = zeros({pad_right}, float32, s); + win = concatenate({left_pad, win, right_pad}, 0, s); + } + + array padded_x = x; + if (center) { + int pad_width = n_fft / 2; + + array left_pad = zeros({pad_width}, x.dtype(), s); + array right_pad = zeros({pad_width}, x.dtype(), s); + padded_x = concatenate({left_pad, padded_x, right_pad}, 0, s); + } + + int n_frames = 1 + (padded_x.shape(0) - n_fft) / hop_length; + + std::vector frames; + for (int i = 0; i < n_frames; ++i) { + array frame = + slice(padded_x, {i * hop_length}, {i * hop_length + n_fft}, s); + frames.push_back(multiply(frame, win, s)); + } + + array stacked_frames = stack(frames, 0, s); + + array stft_result = mlx::core::fft::rfftn(stacked_frames, {n_fft}, {-1}, s); + + if (normalized) { + array n_fft_array = full({1}, static_cast(n_fft), float32, s); + stft_result = divide(stft_result, sqrt(n_fft_array, s), s); + } + + if (onesided) { + stft_result = slice(stft_result, {}, {n_fft / 2 + 1}, s); + } + + return stft_result; +} + +array istft( + const array& stft_matrix, + int hop_length = -1, + int win_length = -1, + const array& window, + bool center = true, + int length = -1, + bool normalized = false, + StreamOrDevice s /* = {} */) { + int n_fft = (stft_matrix.shape(-1) - 1) * 2; + if (hop_length == -1) + hop_length = n_fft / 4; + if (win_length == -1) + win_length = n_fft; + + array win = (window.size() == 0) ? ones({win_length}, float32, s) : window; + + if (win_length < n_fft) { + int pad_left = (n_fft - win_length) / 2; + int pad_right = n_fft - win_length - pad_left; + + array left_pad = zeros({pad_left}, float32, s); + array right_pad = zeros({pad_right}, float32, s); + win = concatenate({left_pad, win, right_pad}, 0, s); + } + + array frames = mlx::core::fft::irfftn(stft_matrix, {n_fft}, {-1}, s); + + frames = multiply(frames, win, s); + + int signal_length = (frames.shape(0) - 1) * hop_length + n_fft; + array signal = zeros({signal_length}, float32, s); + array window_sum = zeros({signal_length}, float32, s); + + for (int i = 0; i < frames.shape(0); ++i) { + array frame = reshape(slice(frames, {i}, {i + 1}, s), {n_fft}, s); + array signal_slice = + slice(signal, {i * hop_length}, {i * hop_length + n_fft}, s); + array window_slice = + slice(window_sum, {i * hop_length}, {i * hop_length + n_fft}, s); + + signal_slice = add(signal_slice, frame, s); + window_slice = add(window_slice, win, s); + } + + signal = divide(signal, window_sum, s); + + if (center) { + int pad_width = n_fft / 2; + signal = slice(signal, {pad_width}, {signal.shape(0) - pad_width}, s); + } + + if (length > 0) { + if (signal.shape(0) > length) { + signal = slice(signal, {0}, {length}, s); + } else if (signal.shape(0) < length) { + int pad_length = length - signal.shape(0); + array pad_array = zeros({pad_length}, signal.dtype(), s); + signal = concatenate({signal, pad_array}, 0, s); + } + } + + return signal; +} + +} // namespace mlx::core::stft From faf76a8133361e8eea54bcc7875b32d460ead358 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Fri, 18 Apr 2025 01:45:06 +0530 Subject: [PATCH 5/9] Added bindings and removed python implementation --- python/mlx/stft.py | 151 --------------------------------------------- python/src/fft.cpp | 96 ++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 151 deletions(-) delete mode 100644 python/mlx/stft.py diff --git a/python/mlx/stft.py b/python/mlx/stft.py deleted file mode 100644 index a4abfdc87..000000000 --- a/python/mlx/stft.py +++ /dev/null @@ -1,151 +0,0 @@ -import mlx.core as mx - - -def stft( - x: mx.array, - n_fft: int = 2048, - hop_length: int = None, - win_length: int = None, - window: mx.array = None, - center: bool = True, - pad_mode: str = "reflect", - normalized: bool = False, - onesided: bool = True, - return_complex: bool = True, -) -> mx.array: - """ - Computes the Short-Time Fourier Transform (STFT) of a signal. - - Args: - x (mx.array): Input signal array. - n_fft (int): Number of FFT points. Default is 2048. - hop_length (int, optional): Number of samples between successive frames. - Defaults to `n_fft // 4`. - win_length (int, optional): Window size. Defaults to `n_fft`. - window (mx.array, optional): Window function. Defaults to a rectangular window. - center (bool): If True, pads the input signal so that the frame is centered. Default is True. - pad_mode (str): Padding mode to use if `center` is True. Default is "reflect". - normalized (bool): If True, normalizes the STFT by the square root of `n_fft`. Default is False. - onesided (bool): If True, returns only the positive frequencies. Default is True. - return_complex (bool): If True, returns a complex-valued STFT. If False, returns real and imaginary parts separately. Default is True. - - Returns: - mx.array: The STFT of the input signal. The shape depends on the input and parameters. - """ - - if hop_length is None: - hop_length = n_fft // 4 - - if win_length is None: - win_length = n_fft - - if window is None: - window = mx.ones(win_length) - - if win_length < n_fft: - pad_left = (n_fft - win_length) // 2 - pad_right = n_fft - win_length - pad_left - window = mx.pad(window, [(pad_left, pad_right)]) - - if center: - pad_width = n_fft // 2 - x = mx.pad(x, [(pad_width, pad_width)], mode=pad_mode) - - n_frames = 1 + (x.shape[0] - n_fft) // hop_length - - frames = mx.stack( - [x[i * hop_length : i * hop_length + n_fft] * window for i in range(n_frames)] - ) - - stft = mx.fft.fft(frames, n=n_fft, axis=-1) - - if normalized: - stft = stft / mx.sqrt(n_fft) - - if onesided: - stft = stft[..., : n_fft // 2 + 1] - - if not return_complex: - stft = mx.stack([stft.real, stft.imag], axis=-1) - - return stft - - -def istft( - stft_matrix: mx.array, - hop_length: int = None, - win_length: int = None, - window: mx.array = None, - center: bool = True, - length: int = None, - normalized: bool = False, - onesided: bool = True, -) -> mx.array: - """ - Computes the inverse Short-Time Fourier Transform (iSTFT) of a signal. - - Args: - stft_matrix (mx.array): STFT matrix (output of `stft`). - hop_length (int, optional): Number of samples between successive frames. - Defaults to `stft_matrix.shape[-2] // 4`. - win_length (int, optional): Window size. Defaults to `stft_matrix.shape[-2]`. - window (mx.array, optional): Window function. Defaults to a rectangular window. - center (bool): If True, removes padding added during STFT computation. Default is True. - length (int, optional): Length of the output signal. If provided, the output is trimmed or zero-padded to this length. - normalized (bool): If True, normalizes the iSTFT by the square root of the FFT size. Default is False. - onesided (bool): If True, assumes the input STFT matrix is onesided. Default is True. - - Returns: - mx.array: The reconstructed time-domain signal. - """ - if hop_length is None: - hop_length = stft_matrix.shape[-2] // 4 - - if win_length is None: - win_length = stft_matrix.shape[-2] - - if window is None: - window = mx.ones(win_length) - - if win_length < stft_matrix.shape[-2]: - pad_left = (stft_matrix.shape[-2] - win_length) // 2 - pad_right = stft_matrix.shape[-2] - win_length - pad_left - window = mx.pad(window, [(pad_left, pad_right)]) - - if stft_matrix.shape[-1] == 2: - stft_matrix = stft_matrix[..., 0] + 1j * stft_matrix[..., 1] - - if onesided: - n_fft = 2 * (stft_matrix.shape[-1] - 1) - full_stft = mx.zeros((*stft_matrix.shape[:-1], n_fft), dtype=stft_matrix.dtype) - full_stft[..., : stft_matrix.shape[-1]] = stft_matrix - full_stft[..., stft_matrix.shape[-1] :] = mx.conj(stft_matrix[..., -2:0:-1]) - stft_matrix = full_stft - - frames = mx.fft.ifft(stft_matrix, n=stft_matrix.shape[-1], axis=-1) - - if normalized: - frames = frames * mx.sqrt(frames.shape[-1]) - - frames = frames * window - - signal_length = (frames.shape[0] - 1) * hop_length + frames.shape[1] - signal = mx.zeros(signal_length, dtype=frames.dtype) - - for i in range(frames.shape[0]): - signal[i * hop_length : i * hop_length + frames.shape[1]] += frames[i] - - window_sum = mx.zeros(signal_length, dtype=frames.dtype) - for i in range(frames.shape[0]): - window_sum[i * hop_length : i * hop_length + frames.shape[1]] += window - - signal = signal / window_sum - - if center: - pad_width = frames.shape[1] // 2 - signal = signal[pad_width:-pad_width] - - if length is not None: - signal = signal[:length] - - return signal diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 5ad4702e2..675784680 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -459,4 +459,100 @@ void init_fft(nb::module_& parent_module) { Returns: array: The real array containing the inverse of :func:`rfftn`. )pbdoc"); + + m.def( + "stft", + [](const mx::array& x, + int n_fft = 2048, + int hop_length = -1, + int win_length = -1, + const mx::array& window = mx::array(), + bool center = true, + const std::string& pad_mode = "reflect", + bool normalized = false, + bool onesided = true, + mx::StreamOrDevice s = {}) { + return mx::stft::stft( + x, + n_fft, + hop_length, + win_length, + window, + center, + pad_mode, + normalized, + onesided, + s); + }, + "x"_a, + "n_fft"_a = 2048, + "hop_length"_a = -1, + "win_length"_a = -1, + "window"_a = mx::array(), + "center"_a = true, + "pad_mode"_a = "reflect", + "normalized"_a = false, + "onesided"_a = true, + "stream"_a = nb::none(), + R"pbdoc( + Short-time Fourier Transform (STFT). + + Args: + x (array): Input signal. + n_fft (int, optional): Number of FFT points. Default is 2048. + hop_length (int, optional): Number of samples between successive frames. Default is `n_fft // 4`. + win_length (int, optional): Window size. Default is `n_fft`. + window (array, optional): Window function. Default is a rectangular window. + center (bool, optional): Whether to pad the signal to center the frames. Default is True. + pad_mode (str, optional): Padding mode. Default is "reflect". + normalized (bool, optional): Whether to normalize the STFT. Default is False. + onesided (bool, optional): Whether to return a one-sided STFT. Default is True. + + Returns: + array: The STFT of the input signal. + )pbdoc"); + + m.def( + "istft", + [](const mx::array& stft_matrix, + int hop_length = -1, + int win_length = -1, + const mx::array& window = mx::array(), + bool center = true, + int length = -1, + bool normalized = false, + mx::StreamOrDevice s = {}) { + return mx::stft::istft( + stft_matrix, + hop_length, + win_length, + window, + center, + length, + normalized, + s); + }, + "stft_matrix"_a, + "hop_length"_a = -1, + "win_length"_a = -1, + "window"_a = mx::array(), + "center"_a = true, + "length"_a = -1, + "normalized"_a = false, + "stream"_a = nb::none(), + R"pbdoc( + Inverse Short-time Fourier Transform (ISTFT). + + Args: + stft_matrix (array): Input STFT matrix. + hop_length (int, optional): Number of samples between successive frames. Default is `n_fft // 4`. + win_length (int, optional): Window size. Default is `n_fft`. + window (array, optional): Window function. Default is a rectangular window. + center (bool, optional): Whether the signal was padded to center the frames. Default is True. + length (int, optional): Length of the output signal. Default is inferred from the STFT matrix. + normalized (bool, optional): Whether the STFT was normalized. Default is False. + + Returns: + array: The reconstructed signal. + )pbdoc"); } From c92a2bc6791c234f1b63a2915110bab6eeaea988 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Fri, 18 Apr 2025 23:02:55 +0530 Subject: [PATCH 6/9] changed to mx::pad and added stft to a single fft namespace --- mlx/fft.cpp | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/mlx/fft.cpp b/mlx/fft.cpp index dfb585dfc..a8af2eea8 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -191,10 +191,6 @@ array irfftn(const array& a, StreamOrDevice s /* = {} */) { return fft_impl(a, true, true, s); } -} // namespace mlx::core::fft - -namespace mlx::core::stft { - array stft( const array& x, int n_fft = 2048, @@ -216,19 +212,15 @@ array stft( if (win_length < n_fft) { int pad_left = (n_fft - win_length) / 2; int pad_right = n_fft - win_length - pad_left; - - array left_pad = zeros({pad_left}, float32, s); - array right_pad = zeros({pad_right}, float32, s); - win = concatenate({left_pad, win, right_pad}, 0, s); + win = mlx::core::pad( + win, {{pad_left, pad_right}}, array(0, float32), "constant", s); } array padded_x = x; if (center) { int pad_width = n_fft / 2; - - array left_pad = zeros({pad_width}, x.dtype(), s); - array right_pad = zeros({pad_width}, x.dtype(), s); - padded_x = concatenate({left_pad, padded_x, right_pad}, 0, s); + padded_x = mlx::core::pad( + padded_x, {{pad_width, pad_width}}, array(0, x.dtype()), pad_mode, s); } int n_frames = 1 + (padded_x.shape(0) - n_fft) / hop_length; @@ -276,10 +268,8 @@ array istft( if (win_length < n_fft) { int pad_left = (n_fft - win_length) / 2; int pad_right = n_fft - win_length - pad_left; - - array left_pad = zeros({pad_left}, float32, s); - array right_pad = zeros({pad_right}, float32, s); - win = concatenate({left_pad, win, right_pad}, 0, s); + win = mlx::core::pad( + win, {{pad_left, pad_right}}, array(0, float32), "constant", s); } array frames = mlx::core::fft::irfftn(stft_matrix, {n_fft}, {-1}, s); @@ -313,12 +303,12 @@ array istft( signal = slice(signal, {0}, {length}, s); } else if (signal.shape(0) < length) { int pad_length = length - signal.shape(0); - array pad_array = zeros({pad_length}, signal.dtype(), s); - signal = concatenate({signal, pad_array}, 0, s); + signal = mlx::core::pad( + signal, {{0, pad_length}}, array(0, signal.dtype()), "constant", s); } } return signal; } -} // namespace mlx::core::stft +} // namespace mlx::core::fft \ No newline at end of file From 0db393acd737e787a2061d9e53f8ffa7fdb0d742 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Fri, 18 Apr 2025 23:21:10 +0530 Subject: [PATCH 7/9] Fixed fft.cpp and added tests --- mlx/fft.cpp | 34 ++++++++++----------- tests/fft_tests.cpp | 74 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 18 deletions(-) diff --git a/mlx/fft.cpp b/mlx/fft.cpp index a8af2eea8..68b032727 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -225,15 +225,13 @@ array stft( int n_frames = 1 + (padded_x.shape(0) - n_fft) / hop_length; - std::vector frames; - for (int i = 0; i < n_frames; ++i) { - array frame = - slice(padded_x, {i * hop_length}, {i * hop_length + n_fft}, s); - frames.push_back(multiply(frame, win, s)); - } - - array stacked_frames = stack(frames, 0, s); + Shape strided_shape = {n_frames, n_fft}; + Strides strided_strides = { + hop_length * static_cast(sizeof(float32)), + static_cast(sizeof(float32))}; + array frames = as_strided(padded_x, strided_shape, strided_strides, 0, s); + array stacked_frames = multiply(frames, win, s); array stft_result = mlx::core::fft::rfftn(stacked_frames, {n_fft}, {-1}, s); if (normalized) { @@ -273,23 +271,23 @@ array istft( } array frames = mlx::core::fft::irfftn(stft_matrix, {n_fft}, {-1}, s); - frames = multiply(frames, win, s); int signal_length = (frames.shape(0) - 1) * hop_length + n_fft; array signal = zeros({signal_length}, float32, s); array window_sum = zeros({signal_length}, float32, s); - for (int i = 0; i < frames.shape(0); ++i) { - array frame = reshape(slice(frames, {i}, {i + 1}, s), {n_fft}, s); - array signal_slice = - slice(signal, {i * hop_length}, {i * hop_length + n_fft}, s); - array window_slice = - slice(window_sum, {i * hop_length}, {i * hop_length + n_fft}, s); + Shape strided_shape = {frames.shape(0), n_fft}; + Strides strided_strides = { + hop_length * static_cast(sizeof(float32)), + static_cast(sizeof(float32))}; + array signal_strided = + as_strided(signal, strided_shape, strided_strides, 0, s); + array window_sum_strided = + as_strided(window_sum, strided_shape, strided_strides, 0, s); - signal_slice = add(signal_slice, frame, s); - window_slice = add(window_slice, win, s); - } + signal_strided = add(signal_strided, frames, s); + window_sum_strided = add(window_sum_strided, win, s); signal = divide(signal, window_sum, s); diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp index c04dda1d5..b0d8c8e52 100644 --- a/tests/fft_tests.cpp +++ b/tests/fft_tests.cpp @@ -308,3 +308,77 @@ TEST_CASE("test fft grads") { .second; CHECK_EQ(vjp_out.shape(), Shape{5, 5}); } + +TEST_CASE("test stft and istft") { + int n_fft = 4; + int hop_length = 2; + int win_length = 4; + + array signal = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, float32); + + array window = array({0.5, 1.0, 1.0, 0.5}, float32); + + SUBCASE("stft basic functionality") { + auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window); + + CHECK_EQ(stft_result.shape(0), 4); + CHECK_EQ(stft_result.shape(1), 3); + + CHECK_EQ(stft_result.dtype(), complex64); + } + + SUBCASE("istft reconstruction") { + auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window); + auto reconstructed_signal = + fft::istft(stft_result, hop_length, win_length, window); + + CHECK_EQ(reconstructed_signal.shape(0), signal.shape(0)); + CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item()); + } + + SUBCASE("stft with default parameters") { + auto stft_result = fft::stft(signal); + + CHECK_EQ(stft_result.shape(0), 5); + CHECK_EQ(stft_result.shape(1), 3); + + CHECK_EQ(stft_result.dtype(), complex64); + } + + SUBCASE("istft with length parameter") { + auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window); + int length = 6; + auto reconstructed_signal = + fft::istft(stft_result, hop_length, win_length, window, true, length); + + CHECK_EQ(reconstructed_signal.shape(0), length); + + CHECK( + allclose(slice(signal, {0}, {length}), reconstructed_signal, 1e-5, 1e-5) + .item()); + } + + SUBCASE("stft and istft with normalization") { + auto stft_result = fft::stft( + signal, n_fft, hop_length, win_length, window, true, "reflect", true); + auto reconstructed_signal = + fft::istft(stft_result, hop_length, win_length, window, true, -1, true); + + CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item()); + } + + SUBCASE("stft with onesided=False") { + auto stft_result = fft::stft( + signal, + n_fft, + hop_length, + win_length, + window, + true, + "reflect", + false, + false); + + CHECK_EQ(stft_result.shape(1), n_fft); + } +} \ No newline at end of file From f7f323f6aeada6f345320d12ed3b4f84fde45502 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 23 Apr 2025 19:19:04 +0530 Subject: [PATCH 8/9] Added stft and istft in the header --- benchmarks/cpp/autograd.cpp | 2 +- mlx/fft.cpp | 1 - mlx/fft.h | 47 +++++++++++++++++++++++++++++++++++++ python/src/fft.cpp | 16 ++++++------- 4 files changed, 56 insertions(+), 10 deletions(-) diff --git a/benchmarks/cpp/autograd.cpp b/benchmarks/cpp/autograd.cpp index b4303a840..fcdf0c5d6 100644 --- a/benchmarks/cpp/autograd.cpp +++ b/benchmarks/cpp/autograd.cpp @@ -10,7 +10,7 @@ namespace mx = mlx::core; void time_value_and_grad() { auto x = mx::ones({200, 1000}); mx::eval(x); - auto fn = [](mx::array x) { + auto fn = [](mx::x) { for (int i = 0; i < 20; ++i) { x = mx::log(mx::exp(x)); } diff --git a/mlx/fft.cpp b/mlx/fft.cpp index c41b37843..961c1226c 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include "mlx/fft.h" #include "mlx/ops.h" diff --git a/mlx/fft.h b/mlx/fft.h index 2f02da73b..1ccdf300d 100644 --- a/mlx/fft.h +++ b/mlx/fft.h @@ -6,8 +6,11 @@ #include "array.h" #include "device.h" +#include "mlx/mlx.h" #include "utils.h" +namespace mx = mlx::core; + namespace mlx::core::fft { /** Compute the n-dimensional Fourier Transform. */ @@ -146,4 +149,48 @@ inline array irfft2( return irfftn(a, axes, s); } +inline array stft( + const array& x, + int n_fft = 2048, + int hop_length = -1, + int win_length = -1, + const array& window = mx::array({}), + bool center = true, + const std::string& pad_mode = "reflect", + bool normalized = false, + bool onesided = true, + StreamOrDevice s = {}) { + return mlx::core::fft::stft( + x, + n_fft, + hop_length, + win_length, + window, + center, + pad_mode, + normalized, + onesided, + s); +} + +inline array istft( + const array& stft_matrix, + int hop_length = -1, + int win_length = -1, + const array& window = mx::array({}), + bool center = true, + int length = -1, + bool normalized = false, + StreamOrDevice s = {}) { + return mlx::core::fft::istft( + stft_matrix, + hop_length, + win_length, + window, + center, + length, + normalized, + s); +} + } // namespace mlx::core::fft diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 675784680..aadb3893f 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -464,8 +464,8 @@ void init_fft(nb::module_& parent_module) { "stft", [](const mx::array& x, int n_fft = 2048, - int hop_length = -1, - int win_length = -1, + std::optional hop_length = std::nullopt, + std::optional win_length = std::nullopt, const mx::array& window = mx::array(), bool center = true, const std::string& pad_mode = "reflect", @@ -486,8 +486,8 @@ void init_fft(nb::module_& parent_module) { }, "x"_a, "n_fft"_a = 2048, - "hop_length"_a = -1, - "win_length"_a = -1, + "hop_length"_a = nb::none(), + "win_length"_a = nb::none(), "window"_a = mx::array(), "center"_a = true, "pad_mode"_a = "reflect", @@ -515,8 +515,8 @@ void init_fft(nb::module_& parent_module) { m.def( "istft", [](const mx::array& stft_matrix, - int hop_length = -1, - int win_length = -1, + std::optional hop_length = std::nullopt, + std::optional win_length = std::nullopt, const mx::array& window = mx::array(), bool center = true, int length = -1, @@ -533,8 +533,8 @@ void init_fft(nb::module_& parent_module) { s); }, "stft_matrix"_a, - "hop_length"_a = -1, - "win_length"_a = -1, + "hop_length"_a = nb::none(), + "win_length"_a = nb::none(), "window"_a = mx::array(), "center"_a = true, "length"_a = -1, From eeaf1fa463a70768109ae3a0e2520ec05b448011 Mon Sep 17 00:00:00 2001 From: Param Thakkar <128291516+ParamThakkar123@users.noreply.github.com> Date: Wed, 7 May 2025 15:24:01 +0530 Subject: [PATCH 9/9] Update fft_tests.cpp --- tests/fft_tests.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp index 4373e2920..00c80bc0a 100644 --- a/tests/fft_tests.cpp +++ b/tests/fft_tests.cpp @@ -309,7 +309,6 @@ TEST_CASE("test fft grads") { CHECK_EQ(vjp_out.shape(), Shape{5, 5}); } -<<<<<<< HEAD TEST_CASE("test stft and istft") { int n_fft = 4; int hop_length = 2; @@ -383,7 +382,8 @@ TEST_CASE("test stft and istft") { CHECK_EQ(stft_result.shape(1), n_fft); } } -== == == = TEST_CASE("test fftshift and ifftshift") { + +TEST_CASE("test fftshift and ifftshift") { // Test 1D array with even length auto x = arange(8); auto y = fft::fftshift(x); @@ -440,4 +440,3 @@ TEST_CASE("test stft and istft") { CHECK_THROWS_AS(fft::ifftshift(x, {3}), std::invalid_argument); CHECK_THROWS_AS(fft::ifftshift(x, {-5}), std::invalid_argument); } ->>>>>>> 5a1a5d5ed16f69af7c3ce56dd94e4502661e1565