Added stft and istft in the header

This commit is contained in:
paramthakkar123 2025-04-23 19:19:04 +05:30
parent a963a15b8d
commit f7f323f6ae
4 changed files with 56 additions and 10 deletions

View File

@ -10,7 +10,7 @@ namespace mx = mlx::core;
void time_value_and_grad() {
auto x = mx::ones({200, 1000});
mx::eval(x);
auto fn = [](mx::array x) {
auto fn = [](mx::x) {
for (int i = 0; i < 20; ++i) {
x = mx::log(mx::exp(x));
}

View File

@ -3,7 +3,6 @@
#include <cmath>
#include <numeric>
#include <set>
#include <sstream>
#include "mlx/fft.h"
#include "mlx/ops.h"

View File

@ -6,8 +6,11 @@
#include "array.h"
#include "device.h"
#include "mlx/mlx.h"
#include "utils.h"
namespace mx = mlx::core;
namespace mlx::core::fft {
/** Compute the n-dimensional Fourier Transform. */
@ -146,4 +149,48 @@ inline array irfft2(
return irfftn(a, axes, s);
}
inline array stft(
const array& x,
int n_fft = 2048,
int hop_length = -1,
int win_length = -1,
const array& window = mx::array({}),
bool center = true,
const std::string& pad_mode = "reflect",
bool normalized = false,
bool onesided = true,
StreamOrDevice s = {}) {
return mlx::core::fft::stft(
x,
n_fft,
hop_length,
win_length,
window,
center,
pad_mode,
normalized,
onesided,
s);
}
inline array istft(
const array& stft_matrix,
int hop_length = -1,
int win_length = -1,
const array& window = mx::array({}),
bool center = true,
int length = -1,
bool normalized = false,
StreamOrDevice s = {}) {
return mlx::core::fft::istft(
stft_matrix,
hop_length,
win_length,
window,
center,
length,
normalized,
s);
}
} // namespace mlx::core::fft

View File

@ -464,8 +464,8 @@ void init_fft(nb::module_& parent_module) {
"stft",
[](const mx::array& x,
int n_fft = 2048,
int hop_length = -1,
int win_length = -1,
std::optional<int> hop_length = std::nullopt,
std::optional<int> win_length = std::nullopt,
const mx::array& window = mx::array(),
bool center = true,
const std::string& pad_mode = "reflect",
@ -486,8 +486,8 @@ void init_fft(nb::module_& parent_module) {
},
"x"_a,
"n_fft"_a = 2048,
"hop_length"_a = -1,
"win_length"_a = -1,
"hop_length"_a = nb::none(),
"win_length"_a = nb::none(),
"window"_a = mx::array(),
"center"_a = true,
"pad_mode"_a = "reflect",
@ -515,8 +515,8 @@ void init_fft(nb::module_& parent_module) {
m.def(
"istft",
[](const mx::array& stft_matrix,
int hop_length = -1,
int win_length = -1,
std::optional<int> hop_length = std::nullopt,
std::optional<int> win_length = std::nullopt,
const mx::array& window = mx::array(),
bool center = true,
int length = -1,
@ -533,8 +533,8 @@ void init_fft(nb::module_& parent_module) {
s);
},
"stft_matrix"_a,
"hop_length"_a = -1,
"win_length"_a = -1,
"hop_length"_a = nb::none(),
"win_length"_a = nb::none(),
"window"_a = mx::array(),
"center"_a = true,
"length"_a = -1,