Added stft and istft in the header

This commit is contained in:
paramthakkar123 2025-04-23 19:19:04 +05:30
parent a963a15b8d
commit f7f323f6ae
4 changed files with 56 additions and 10 deletions

View File

@ -10,7 +10,7 @@ namespace mx = mlx::core;
void time_value_and_grad() { void time_value_and_grad() {
auto x = mx::ones({200, 1000}); auto x = mx::ones({200, 1000});
mx::eval(x); mx::eval(x);
auto fn = [](mx::array x) { auto fn = [](mx::x) {
for (int i = 0; i < 20; ++i) { for (int i = 0; i < 20; ++i) {
x = mx::log(mx::exp(x)); x = mx::log(mx::exp(x));
} }

View File

@ -3,7 +3,6 @@
#include <cmath> #include <cmath>
#include <numeric> #include <numeric>
#include <set> #include <set>
#include <sstream>
#include "mlx/fft.h" #include "mlx/fft.h"
#include "mlx/ops.h" #include "mlx/ops.h"

View File

@ -6,8 +6,11 @@
#include "array.h" #include "array.h"
#include "device.h" #include "device.h"
#include "mlx/mlx.h"
#include "utils.h" #include "utils.h"
namespace mx = mlx::core;
namespace mlx::core::fft { namespace mlx::core::fft {
/** Compute the n-dimensional Fourier Transform. */ /** Compute the n-dimensional Fourier Transform. */
@ -146,4 +149,48 @@ inline array irfft2(
return irfftn(a, axes, s); return irfftn(a, axes, s);
} }
inline array stft(
const array& x,
int n_fft = 2048,
int hop_length = -1,
int win_length = -1,
const array& window = mx::array({}),
bool center = true,
const std::string& pad_mode = "reflect",
bool normalized = false,
bool onesided = true,
StreamOrDevice s = {}) {
return mlx::core::fft::stft(
x,
n_fft,
hop_length,
win_length,
window,
center,
pad_mode,
normalized,
onesided,
s);
}
inline array istft(
const array& stft_matrix,
int hop_length = -1,
int win_length = -1,
const array& window = mx::array({}),
bool center = true,
int length = -1,
bool normalized = false,
StreamOrDevice s = {}) {
return mlx::core::fft::istft(
stft_matrix,
hop_length,
win_length,
window,
center,
length,
normalized,
s);
}
} // namespace mlx::core::fft } // namespace mlx::core::fft

View File

@ -464,8 +464,8 @@ void init_fft(nb::module_& parent_module) {
"stft", "stft",
[](const mx::array& x, [](const mx::array& x,
int n_fft = 2048, int n_fft = 2048,
int hop_length = -1, std::optional<int> hop_length = std::nullopt,
int win_length = -1, std::optional<int> win_length = std::nullopt,
const mx::array& window = mx::array(), const mx::array& window = mx::array(),
bool center = true, bool center = true,
const std::string& pad_mode = "reflect", const std::string& pad_mode = "reflect",
@ -486,8 +486,8 @@ void init_fft(nb::module_& parent_module) {
}, },
"x"_a, "x"_a,
"n_fft"_a = 2048, "n_fft"_a = 2048,
"hop_length"_a = -1, "hop_length"_a = nb::none(),
"win_length"_a = -1, "win_length"_a = nb::none(),
"window"_a = mx::array(), "window"_a = mx::array(),
"center"_a = true, "center"_a = true,
"pad_mode"_a = "reflect", "pad_mode"_a = "reflect",
@ -515,8 +515,8 @@ void init_fft(nb::module_& parent_module) {
m.def( m.def(
"istft", "istft",
[](const mx::array& stft_matrix, [](const mx::array& stft_matrix,
int hop_length = -1, std::optional<int> hop_length = std::nullopt,
int win_length = -1, std::optional<int> win_length = std::nullopt,
const mx::array& window = mx::array(), const mx::array& window = mx::array(),
bool center = true, bool center = true,
int length = -1, int length = -1,
@ -533,8 +533,8 @@ void init_fft(nb::module_& parent_module) {
s); s);
}, },
"stft_matrix"_a, "stft_matrix"_a,
"hop_length"_a = -1, "hop_length"_a = nb::none(),
"win_length"_a = -1, "win_length"_a = nb::none(),
"window"_a = mx::array(), "window"_a = mx::array(),
"center"_a = true, "center"_a = true,
"length"_a = -1, "length"_a = -1,