From f7f323f6aeada6f345320d12ed3b4f84fde45502 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 23 Apr 2025 19:19:04 +0530 Subject: [PATCH] Added stft and istft in the header --- benchmarks/cpp/autograd.cpp | 2 +- mlx/fft.cpp | 1 - mlx/fft.h | 47 +++++++++++++++++++++++++++++++++++++ python/src/fft.cpp | 16 ++++++------- 4 files changed, 56 insertions(+), 10 deletions(-) diff --git a/benchmarks/cpp/autograd.cpp b/benchmarks/cpp/autograd.cpp index b4303a840..fcdf0c5d6 100644 --- a/benchmarks/cpp/autograd.cpp +++ b/benchmarks/cpp/autograd.cpp @@ -10,7 +10,7 @@ namespace mx = mlx::core; void time_value_and_grad() { auto x = mx::ones({200, 1000}); mx::eval(x); - auto fn = [](mx::array x) { + auto fn = [](mx::x) { for (int i = 0; i < 20; ++i) { x = mx::log(mx::exp(x)); } diff --git a/mlx/fft.cpp b/mlx/fft.cpp index c41b37843..961c1226c 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include "mlx/fft.h" #include "mlx/ops.h" diff --git a/mlx/fft.h b/mlx/fft.h index 2f02da73b..1ccdf300d 100644 --- a/mlx/fft.h +++ b/mlx/fft.h @@ -6,8 +6,11 @@ #include "array.h" #include "device.h" +#include "mlx/mlx.h" #include "utils.h" +namespace mx = mlx::core; + namespace mlx::core::fft { /** Compute the n-dimensional Fourier Transform. */ @@ -146,4 +149,48 @@ inline array irfft2( return irfftn(a, axes, s); } +inline array stft( + const array& x, + int n_fft = 2048, + int hop_length = -1, + int win_length = -1, + const array& window = mx::array({}), + bool center = true, + const std::string& pad_mode = "reflect", + bool normalized = false, + bool onesided = true, + StreamOrDevice s = {}) { + return mlx::core::fft::stft( + x, + n_fft, + hop_length, + win_length, + window, + center, + pad_mode, + normalized, + onesided, + s); +} + +inline array istft( + const array& stft_matrix, + int hop_length = -1, + int win_length = -1, + const array& window = mx::array({}), + bool center = true, + int length = -1, + bool normalized = false, + StreamOrDevice s = {}) { + return mlx::core::fft::istft( + stft_matrix, + hop_length, + win_length, + window, + center, + length, + normalized, + s); +} + } // namespace mlx::core::fft diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 675784680..aadb3893f 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -464,8 +464,8 @@ void init_fft(nb::module_& parent_module) { "stft", [](const mx::array& x, int n_fft = 2048, - int hop_length = -1, - int win_length = -1, + std::optional hop_length = std::nullopt, + std::optional win_length = std::nullopt, const mx::array& window = mx::array(), bool center = true, const std::string& pad_mode = "reflect", @@ -486,8 +486,8 @@ void init_fft(nb::module_& parent_module) { }, "x"_a, "n_fft"_a = 2048, - "hop_length"_a = -1, - "win_length"_a = -1, + "hop_length"_a = nb::none(), + "win_length"_a = nb::none(), "window"_a = mx::array(), "center"_a = true, "pad_mode"_a = "reflect", @@ -515,8 +515,8 @@ void init_fft(nb::module_& parent_module) { m.def( "istft", [](const mx::array& stft_matrix, - int hop_length = -1, - int win_length = -1, + std::optional hop_length = std::nullopt, + std::optional win_length = std::nullopt, const mx::array& window = mx::array(), bool center = true, int length = -1, @@ -533,8 +533,8 @@ void init_fft(nb::module_& parent_module) { s); }, "stft_matrix"_a, - "hop_length"_a = -1, - "win_length"_a = -1, + "hop_length"_a = nb::none(), + "win_length"_a = nb::none(), "window"_a = mx::array(), "center"_a = true, "length"_a = -1,