diff --git a/mlx/fft.cpp b/mlx/fft.cpp
--- a/mlx/fft.cpp
+++ b/mlx/fft.cpp
@@ -189,6 +189,201 @@ array irfftn(const array& a, StreamOrDevice s /* = {} */) {
   return fft_impl(a, true, true, s);
 }
 
+namespace {
+
+// Checks the framing parameters shared by stft and istft.
+void check_stft_params(
+    const std::string& tag,
+    int n_fft,
+    int hop_length,
+    int win_length) {
+  if (n_fft <= 0) {
+    throw std::invalid_argument(tag + " n_fft must be positive.");
+  }
+  if (hop_length <= 0) {
+    throw std::invalid_argument(tag + " hop_length must be positive.");
+  }
+  if (win_length <= 0 || win_length > n_fft) {
+    throw std::invalid_argument(tag + " win_length must be in [1, n_fft].");
+  }
+}
+
+// Returns the window as float32, zero-padded on both sides to n_fft samples.
+// An empty window selects a rectangular window of win_length ones.
+array make_window(
+    const std::string& tag,
+    const array& window,
+    int win_length,
+    int n_fft,
+    StreamOrDevice s) {
+  array win = (window.size() == 0) ? ones({win_length}, float32, s)
+                                   : astype(window, float32, s);
+  if (win.ndim() != 1 || win.shape(0) != win_length) {
+    throw std::invalid_argument(
+        tag + " window must be 1-D with win_length elements.");
+  }
+  if (win_length < n_fft) {
+    int pad_left = (n_fft - win_length) / 2;
+    int pad_right = n_fft - win_length - pad_left;
+    std::vector<std::pair<int, int>> widths = {{pad_left, pad_right}};
+    win = pad(win, widths, array(0.0f), "constant", s);
+  }
+  return win;
+}
+
+// Pads a 1-D signal with pad_width samples on each side. "reflect" mirrors the
+// signal about its end samples without repeating them; any other mode is
+// forwarded to pad() with a zero fill value.
+array pad_signal(
+    const array& x,
+    int pad_width,
+    const std::string& pad_mode,
+    StreamOrDevice s) {
+  if (pad_width == 0) {
+    return x;
+  }
+  if (pad_mode != "reflect") {
+    std::vector<std::pair<int, int>> widths = {{pad_width, pad_width}};
+    return pad(x, widths, array(0, x.dtype()), pad_mode, s);
+  }
+  int n = x.shape(0);
+  if (pad_width >= n) {
+    throw std::invalid_argument(
+        "[stft] Reflect padding needs more than n_fft / 2 input samples.");
+  }
+  // Left pad is x[pad_width], ..., x[1]; right pad is x[n - 2], ...,
+  // x[n - 1 - pad_width].
+  array left = take(x, arange(pad_width, 0, -1, int32, s), 0, s);
+  array right = take(x, arange(n - 2, n - 2 - pad_width, -1, int32, s), 0, s);
+  return concatenate({left, x, right}, 0, s);
+}
+
+} // namespace
+
+array stft(
+    const array& x,
+    int n_fft /* = 2048 */,
+    int hop_length /* = -1 */,
+    int win_length /* = -1 */,
+    const array& window /* = array({}) */,
+    bool center /* = true */,
+    const std::string& pad_mode /* = "reflect" */,
+    bool normalized /* = false */,
+    bool onesided /* = true */,
+    StreamOrDevice s /* = {} */) {
+  if (x.ndim() != 1) {
+    throw std::invalid_argument("[stft] Only 1-D signals are supported.");
+  }
+  if (hop_length == -1) {
+    hop_length = n_fft / 4;
+  }
+  if (win_length == -1) {
+    win_length = n_fft;
+  }
+  check_stft_params("[stft]", n_fft, hop_length, win_length);
+  array win = make_window("[stft]", window, win_length, n_fft, s);
+
+  array padded_x = center ? pad_signal(x, n_fft / 2, pad_mode, s) : x;
+  if (padded_x.shape(0) < n_fft) {
+    throw std::invalid_argument("[stft] Signal is shorter than n_fft.");
+  }
+  int n_frames = 1 + (padded_x.shape(0) - n_fft) / hop_length;
+
+  // as_strided takes strides in elements, not bytes: frames start hop_length
+  // samples apart and the samples inside a frame are adjacent.
+  Shape frame_shape = {n_frames, n_fft};
+  Strides frame_strides = {hop_length, 1};
+  array frames = as_strided(padded_x, frame_shape, frame_strides, 0, s);
+  frames = multiply(frames, win, s);
+
+  // rfftn already keeps only the n_fft / 2 + 1 non-negative frequency bins;
+  // the two-sided output needs the full complex transform instead.
+  array result = onesided ? rfftn(frames, {n_fft}, {-1}, s)
+                          : fftn(frames, {n_fft}, {-1}, s);
+  if (normalized) {
+    result = divide(result, sqrt(array(static_cast<float>(n_fft)), s), s);
+  }
+  return result;
+}
+
+array istft(
+    const array& stft_matrix,
+    int hop_length /* = -1 */,
+    int win_length /* = -1 */,
+    const array& window /* = array({}) */,
+    bool center /* = true */,
+    int length /* = -1 */,
+    bool normalized /* = false */,
+    StreamOrDevice s /* = {} */) {
+  if (stft_matrix.ndim() != 2 || stft_matrix.shape(0) < 1 ||
+      stft_matrix.shape(1) < 2) {
+    throw std::invalid_argument(
+        "[istft] Expected a one-sided STFT of shape (n_frames, n_bins).");
+  }
+  int n_frames = stft_matrix.shape(0);
+  int n_fft = (stft_matrix.shape(1) - 1) * 2;
+  if (hop_length == -1) {
+    hop_length = n_fft / 4;
+  }
+  if (win_length == -1) {
+    win_length = n_fft;
+  }
+  check_stft_params("[istft]", n_fft, hop_length, win_length);
+  array win = make_window("[istft]", window, win_length, n_fft, s);
+
+  array frames = irfftn(stft_matrix, {n_fft}, {-1}, s);
+  if (normalized) {
+    // Undo the 1 / sqrt(n_fft) scaling applied by stft.
+    frames = multiply(frames, sqrt(array(static_cast<float>(n_fft)), s), s);
+  }
+  frames = multiply(frames, win, s);
+
+  // Overlap-add: element k of frame t lands on sample t * hop_length + k.
+  // add() returns a new array instead of writing through a strided view, so
+  // the overlapping sums are accumulated with scatter_add.
+  int signal_length = (n_frames - 1) * hop_length + n_fft;
+  array starts = reshape(
+      arange(0, n_frames * hop_length, hop_length, int32, s), {n_frames, 1}, s);
+  array indices = add(starts, arange(0, n_fft, 1, int32, s), s);
+  Shape update_shape = {n_frames, n_fft, 1};
+  array signal = scatter_add(
+      zeros({signal_length}, float32, s),
+      indices,
+      reshape(frames, update_shape, s),
+      0,
+      s);
+
+  // The window is applied in stft and again above, so each sample is weighted
+  // by the sum of the squared window over the frames that cover it.
+  array win_sq = broadcast_to(square(win, s), {n_frames, n_fft}, s);
+  array envelope = scatter_add(
+      zeros({signal_length}, float32, s),
+      indices,
+      reshape(win_sq, update_shape, s),
+      0,
+      s);
+  // Samples with (near) zero weight stay unscaled instead of dividing by 0.
+  array covered = greater(envelope, array(1e-8f), s);
+  signal = divide(signal, where(covered, envelope, array(1.0f), s), s);
+
+  if (center) {
+    int pad_width = n_fft / 2;
+    signal = slice(signal, {pad_width}, {signal_length - pad_width}, s);
+  }
+
+  if (length > 0) {
+    int current = signal.shape(0);
+    if (current > length) {
+      signal = slice(signal, {0}, {length}, s);
+    } else if (current < length) {
+      std::vector<std::pair<int, int>> widths = {{0, length - current}};
+      signal = pad(signal, widths, array(0.0f), "constant", s);
+    }
+  }
+
+  return signal;
+}
+
 array fftshift(
     const array& a,
     const std::vector<int>& axes,
diff --git a/mlx/fft.h b/mlx/fft.h
--- a/mlx/fft.h
+++ b/mlx/fft.h
@@ -6,8 +6,10 @@
 
+#include <string>
+
 #include "array.h"
 #include "device.h"
 #include "utils.h"
 
 namespace mlx::core::fft {
 
 /** Compute the n-dimensional Fourier Transform. */
@@ -164,4 +166,46 @@ array ifftshift(
     const std::vector<int>& axes,
     StreamOrDevice s = {});
 
+/**
+ * Compute the Short-Time Fourier Transform (STFT) of a 1-D signal.
+ *
+ * The signal is cut into frames of n_fft samples taken every hop_length
+ * samples; each frame is multiplied by the window and transformed. The
+ * output has shape (n_frames, n_fft / 2 + 1) when onesided is true and
+ * (n_frames, n_fft) otherwise.
+ *
+ * hop_length == -1 selects n_fft / 4 and win_length == -1 selects n_fft. An
+ * empty window selects a rectangular window. When center is true the signal
+ * is first padded with n_fft / 2 samples on each side using pad_mode.
+ */
+array stft(
+    const array& x,
+    int n_fft = 2048,
+    int hop_length = -1,
+    int win_length = -1,
+    const array& window = array({}),
+    bool center = true,
+    const std::string& pad_mode = "reflect",
+    bool normalized = false,
+    bool onesided = true,
+    StreamOrDevice s = {});
+
+/**
+ * Reconstruct a 1-D signal from a one-sided STFT of shape
+ * (n_frames, n_fft / 2 + 1) by windowed overlap-add.
+ *
+ * n_fft is inferred as 2 * (n_bins - 1). The framing arguments should match
+ * the ones passed to stft. A positive length truncates or zero-pads the
+ * result to exactly that many samples.
+ */
+array istft(
+    const array& stft_matrix,
+    int hop_length = -1,
+    int win_length = -1,
+    const array& window = array({}),
+    bool center = true,
+    int length = -1,
+    bool normalized = false,
+    StreamOrDevice s = {});
+
 } // namespace mlx::core::fft
diff --git a/python/src/fft.cpp b/python/src/fft.cpp
--- a/python/src/fft.cpp
+++ b/python/src/fft.cpp
@@ -459,6 +459,116 @@ void init_fft(nb::module_& parent_module) {
         Returns:
             array: The real array containing the inverse of :func:`rfftn`.
       )pbdoc");
+  m.def(
+      "stft",
+      [](const mx::array& x,
+         int n_fft,
+         const std::optional<int>& hop_length,
+         const std::optional<int>& win_length,
+         const std::optional<mx::array>& window,
+         bool center,
+         const std::string& pad_mode,
+         bool normalized,
+         bool onesided,
+         mx::StreamOrDevice s) {
+        // NOTE(review): the std::string argument needs the caster from
+        // <nanobind/stl/string.h>; confirm this file includes it.
+        return mx::fft::stft(
+            x,
+            n_fft,
+            hop_length.value_or(-1),
+            win_length.value_or(-1),
+            window.value_or(mx::array({})),
+            center,
+            pad_mode,
+            normalized,
+            onesided,
+            s);
+      },
+      "x"_a,
+      "n_fft"_a = 2048,
+      "hop_length"_a = nb::none(),
+      "win_length"_a = nb::none(),
+      "window"_a = nb::none(),
+      "center"_a = true,
+      "pad_mode"_a = "reflect",
+      "normalized"_a = false,
+      "onesided"_a = true,
+      "stream"_a = nb::none(),
+      R"pbdoc(
+        Short-time Fourier Transform (STFT) of a 1-D signal.
+
+        Args:
+            x (array): The 1-D input signal.
+            n_fft (int, optional): Number of samples per frame and FFT size.
+              Default: ``2048``.
+            hop_length (int, optional): Number of samples between successive
+              frames. Default: ``n_fft // 4``.
+            win_length (int, optional): Window size. Default: ``n_fft``.
+            window (array, optional): Window function with ``win_length``
+              elements. Default: a rectangular window.
+            center (bool, optional): Whether to pad the signal by
+              ``n_fft // 2`` on both sides so frames are centered.
+              Default: ``True``.
+            pad_mode (str, optional): Padding mode used when ``center`` is
+              ``True``. Default: ``"reflect"``.
+            normalized (bool, optional): Whether to scale the result by
+              ``1 / sqrt(n_fft)``. Default: ``False``.
+            onesided (bool, optional): Whether to return only the
+              ``n_fft // 2 + 1`` non-negative frequency bins.
+              Default: ``True``.
+
+        Returns:
+            array: The complex STFT with one row per frame.
+      )pbdoc");
+  m.def(
+      "istft",
+      [](const mx::array& stft_matrix,
+         const std::optional<int>& hop_length,
+         const std::optional<int>& win_length,
+         const std::optional<mx::array>& window,
+         bool center,
+         int length,
+         bool normalized,
+         mx::StreamOrDevice s) {
+        return mx::fft::istft(
+            stft_matrix,
+            hop_length.value_or(-1),
+            win_length.value_or(-1),
+            window.value_or(mx::array({})),
+            center,
+            length,
+            normalized,
+            s);
+      },
+      "stft_matrix"_a,
+      "hop_length"_a = nb::none(),
+      "win_length"_a = nb::none(),
+      "window"_a = nb::none(),
+      "center"_a = true,
+      "length"_a = -1,
+      "normalized"_a = false,
+      "stream"_a = nb::none(),
+      R"pbdoc(
+        Inverse Short-time Fourier Transform (ISTFT).
+
+        Args:
+            stft_matrix (array): One-sided STFT with one row per frame.
+            hop_length (int, optional): Number of samples between successive
+              frames. Default: ``n_fft // 4``.
+            win_length (int, optional): Window size. Default: ``n_fft``.
+            window (array, optional): Window function with ``win_length``
+              elements. Default: a rectangular window.
+            center (bool, optional): Whether the signal was padded to center
+              the frames. Default: ``True``.
+            length (int, optional): Length of the output signal. A
+              non-positive value keeps the inferred length. Default: ``-1``.
+            normalized (bool, optional): Whether the STFT was normalized.
+              Default: ``False``.
+
+        Returns:
+            array: The reconstructed real signal.
+      )pbdoc");
   m.def(
       "fftshift",
       [](const mx::array& a,
diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp
--- a/tests/fft_tests.cpp
+++ b/tests/fft_tests.cpp
@@ -307,6 +307,85 @@ TEST_CASE("test fft grads") {
   CHECK_EQ(vjp_out.shape(), Shape{5, 5});
 }
 
+TEST_CASE("test stft and istft") {
+  int n_fft = 4;
+  int hop_length = 2;
+  int win_length = 4;
+
+  array signal = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, float32);
+  array window = array({0.5, 1.0, 1.0, 0.5}, float32);
+
+  SUBCASE("stft basic functionality") {
+    auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
+
+    // Centering pads the 8 samples to 12: 1 + (12 - 4) / 2 = 5 frames.
+    CHECK_EQ(stft_result.shape(0), 5);
+    CHECK_EQ(stft_result.shape(1), n_fft / 2 + 1);
+    CHECK_EQ(stft_result.dtype(), complex64);
+  }
+
+  SUBCASE("istft reconstruction") {
+    auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
+    auto reconstructed_signal =
+        fft::istft(stft_result, hop_length, win_length, window);
+
+    CHECK_EQ(reconstructed_signal.shape(0), signal.shape(0));
+    CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item<bool>());
+  }
+
+  SUBCASE("stft with default parameters") {
+    // Defaults are n_fft = 2048 and hop_length = 512, so the input has to be
+    // longer than n_fft / 2 for reflect padding: 1 + 4096 / 512 = 9 frames.
+    auto stft_result = fft::stft(ones({4096}));
+
+    CHECK_EQ(stft_result.shape(0), 9);
+    CHECK_EQ(stft_result.shape(1), 1025);
+    CHECK_EQ(stft_result.dtype(), complex64);
+  }
+
+  SUBCASE("stft rejects invalid input") {
+    // The default n_fft needs more than 1024 samples for reflect padding.
+    CHECK_THROWS(fft::stft(signal));
+    CHECK_THROWS(fft::stft(signal, n_fft, 0));
+  }
+
+  SUBCASE("istft with length parameter") {
+    auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
+    int length = 6;
+    auto reconstructed_signal =
+        fft::istft(stft_result, hop_length, win_length, window, true, length);
+
+    CHECK_EQ(reconstructed_signal.shape(0), length);
+    CHECK(
+        allclose(slice(signal, {0}, {length}), reconstructed_signal, 1e-5, 1e-5)
+            .item<bool>());
+  }
+
+  SUBCASE("stft and istft with normalization") {
+    auto stft_result = fft::stft(
+        signal, n_fft, hop_length, win_length, window, true, "reflect", true);
+    auto reconstructed_signal =
+        fft::istft(stft_result, hop_length, win_length, window, true, -1, true);
+
+    CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item<bool>());
+  }
+
+  SUBCASE("stft with onesided=False") {
+    auto stft_result = fft::stft(
+        signal,
+        n_fft,
+        hop_length,
+        win_length,
+        window,
+        true,
+        "reflect",
+        false,
+        false);
+
+    CHECK_EQ(stft_result.shape(1), n_fft);
+  }
+}
+
 TEST_CASE("test fftshift and ifftshift") {
   // Test 1D array with even length
   auto x = arange(8);