mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Added stft and istft in the header
This commit is contained in:
parent
a963a15b8d
commit
f7f323f6ae
@ -10,7 +10,7 @@ namespace mx = mlx::core;
|
||||
void time_value_and_grad() {
|
||||
auto x = mx::ones({200, 1000});
|
||||
mx::eval(x);
|
||||
auto fn = [](mx::array x) {
|
||||
auto fn = [](mx::x) {
|
||||
for (int i = 0; i < 20; ++i) {
|
||||
x = mx::log(mx::exp(x));
|
||||
}
|
||||
|
@ -3,7 +3,6 @@
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/fft.h"
|
||||
#include "mlx/ops.h"
|
||||
|
47
mlx/fft.h
47
mlx/fft.h
@ -6,8 +6,11 @@
|
||||
|
||||
#include "array.h"
|
||||
#include "device.h"
|
||||
#include "mlx/mlx.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
namespace mlx::core::fft {
|
||||
|
||||
/** Compute the n-dimensional Fourier Transform. */
|
||||
@ -146,4 +149,48 @@ inline array irfft2(
|
||||
return irfftn(a, axes, s);
|
||||
}
|
||||
|
||||
inline array stft(
|
||||
const array& x,
|
||||
int n_fft = 2048,
|
||||
int hop_length = -1,
|
||||
int win_length = -1,
|
||||
const array& window = mx::array({}),
|
||||
bool center = true,
|
||||
const std::string& pad_mode = "reflect",
|
||||
bool normalized = false,
|
||||
bool onesided = true,
|
||||
StreamOrDevice s = {}) {
|
||||
return mlx::core::fft::stft(
|
||||
x,
|
||||
n_fft,
|
||||
hop_length,
|
||||
win_length,
|
||||
window,
|
||||
center,
|
||||
pad_mode,
|
||||
normalized,
|
||||
onesided,
|
||||
s);
|
||||
}
|
||||
|
||||
inline array istft(
|
||||
const array& stft_matrix,
|
||||
int hop_length = -1,
|
||||
int win_length = -1,
|
||||
const array& window = mx::array({}),
|
||||
bool center = true,
|
||||
int length = -1,
|
||||
bool normalized = false,
|
||||
StreamOrDevice s = {}) {
|
||||
return mlx::core::fft::istft(
|
||||
stft_matrix,
|
||||
hop_length,
|
||||
win_length,
|
||||
window,
|
||||
center,
|
||||
length,
|
||||
normalized,
|
||||
s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fft
|
||||
|
@ -464,8 +464,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
"stft",
|
||||
[](const mx::array& x,
|
||||
int n_fft = 2048,
|
||||
int hop_length = -1,
|
||||
int win_length = -1,
|
||||
std::optional<int> hop_length = std::nullopt,
|
||||
std::optional<int> win_length = std::nullopt,
|
||||
const mx::array& window = mx::array(),
|
||||
bool center = true,
|
||||
const std::string& pad_mode = "reflect",
|
||||
@ -486,8 +486,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
},
|
||||
"x"_a,
|
||||
"n_fft"_a = 2048,
|
||||
"hop_length"_a = -1,
|
||||
"win_length"_a = -1,
|
||||
"hop_length"_a = nb::none(),
|
||||
"win_length"_a = nb::none(),
|
||||
"window"_a = mx::array(),
|
||||
"center"_a = true,
|
||||
"pad_mode"_a = "reflect",
|
||||
@ -515,8 +515,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"istft",
|
||||
[](const mx::array& stft_matrix,
|
||||
int hop_length = -1,
|
||||
int win_length = -1,
|
||||
std::optional<int> hop_length = std::nullopt,
|
||||
std::optional<int> win_length = std::nullopt,
|
||||
const mx::array& window = mx::array(),
|
||||
bool center = true,
|
||||
int length = -1,
|
||||
@ -533,8 +533,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
s);
|
||||
},
|
||||
"stft_matrix"_a,
|
||||
"hop_length"_a = -1,
|
||||
"win_length"_a = -1,
|
||||
"hop_length"_a = nb::none(),
|
||||
"win_length"_a = nb::none(),
|
||||
"window"_a = mx::array(),
|
||||
"center"_a = true,
|
||||
"length"_a = -1,
|
||||
|
Loading…
Reference in New Issue
Block a user