mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Added stft and istft in the header
This commit is contained in:
parent
a963a15b8d
commit
f7f323f6ae
@ -10,7 +10,7 @@ namespace mx = mlx::core;
|
|||||||
void time_value_and_grad() {
|
void time_value_and_grad() {
|
||||||
auto x = mx::ones({200, 1000});
|
auto x = mx::ones({200, 1000});
|
||||||
mx::eval(x);
|
mx::eval(x);
|
||||||
auto fn = [](mx::array x) {
|
auto fn = [](mx::x) {
|
||||||
for (int i = 0; i < 20; ++i) {
|
for (int i = 0; i < 20; ++i) {
|
||||||
x = mx::log(mx::exp(x));
|
x = mx::log(mx::exp(x));
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "mlx/fft.h"
|
#include "mlx/fft.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
|
47
mlx/fft.h
47
mlx/fft.h
@ -6,8 +6,11 @@
|
|||||||
|
|
||||||
#include "array.h"
|
#include "array.h"
|
||||||
#include "device.h"
|
#include "device.h"
|
||||||
|
#include "mlx/mlx.h"
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
|
|
||||||
namespace mlx::core::fft {
|
namespace mlx::core::fft {
|
||||||
|
|
||||||
/** Compute the n-dimensional Fourier Transform. */
|
/** Compute the n-dimensional Fourier Transform. */
|
||||||
@ -146,4 +149,48 @@ inline array irfft2(
|
|||||||
return irfftn(a, axes, s);
|
return irfftn(a, axes, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline array stft(
|
||||||
|
const array& x,
|
||||||
|
int n_fft = 2048,
|
||||||
|
int hop_length = -1,
|
||||||
|
int win_length = -1,
|
||||||
|
const array& window = mx::array({}),
|
||||||
|
bool center = true,
|
||||||
|
const std::string& pad_mode = "reflect",
|
||||||
|
bool normalized = false,
|
||||||
|
bool onesided = true,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return mlx::core::fft::stft(
|
||||||
|
x,
|
||||||
|
n_fft,
|
||||||
|
hop_length,
|
||||||
|
win_length,
|
||||||
|
window,
|
||||||
|
center,
|
||||||
|
pad_mode,
|
||||||
|
normalized,
|
||||||
|
onesided,
|
||||||
|
s);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline array istft(
|
||||||
|
const array& stft_matrix,
|
||||||
|
int hop_length = -1,
|
||||||
|
int win_length = -1,
|
||||||
|
const array& window = mx::array({}),
|
||||||
|
bool center = true,
|
||||||
|
int length = -1,
|
||||||
|
bool normalized = false,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return mlx::core::fft::istft(
|
||||||
|
stft_matrix,
|
||||||
|
hop_length,
|
||||||
|
win_length,
|
||||||
|
window,
|
||||||
|
center,
|
||||||
|
length,
|
||||||
|
normalized,
|
||||||
|
s);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::fft
|
} // namespace mlx::core::fft
|
||||||
|
@ -464,8 +464,8 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
"stft",
|
"stft",
|
||||||
[](const mx::array& x,
|
[](const mx::array& x,
|
||||||
int n_fft = 2048,
|
int n_fft = 2048,
|
||||||
int hop_length = -1,
|
std::optional<int> hop_length = std::nullopt,
|
||||||
int win_length = -1,
|
std::optional<int> win_length = std::nullopt,
|
||||||
const mx::array& window = mx::array(),
|
const mx::array& window = mx::array(),
|
||||||
bool center = true,
|
bool center = true,
|
||||||
const std::string& pad_mode = "reflect",
|
const std::string& pad_mode = "reflect",
|
||||||
@ -486,8 +486,8 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
},
|
},
|
||||||
"x"_a,
|
"x"_a,
|
||||||
"n_fft"_a = 2048,
|
"n_fft"_a = 2048,
|
||||||
"hop_length"_a = -1,
|
"hop_length"_a = nb::none(),
|
||||||
"win_length"_a = -1,
|
"win_length"_a = nb::none(),
|
||||||
"window"_a = mx::array(),
|
"window"_a = mx::array(),
|
||||||
"center"_a = true,
|
"center"_a = true,
|
||||||
"pad_mode"_a = "reflect",
|
"pad_mode"_a = "reflect",
|
||||||
@ -515,8 +515,8 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
m.def(
|
m.def(
|
||||||
"istft",
|
"istft",
|
||||||
[](const mx::array& stft_matrix,
|
[](const mx::array& stft_matrix,
|
||||||
int hop_length = -1,
|
std::optional<int> hop_length = std::nullopt,
|
||||||
int win_length = -1,
|
std::optional<int> win_length = std::nullopt,
|
||||||
const mx::array& window = mx::array(),
|
const mx::array& window = mx::array(),
|
||||||
bool center = true,
|
bool center = true,
|
||||||
int length = -1,
|
int length = -1,
|
||||||
@ -533,8 +533,8 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
s);
|
s);
|
||||||
},
|
},
|
||||||
"stft_matrix"_a,
|
"stft_matrix"_a,
|
||||||
"hop_length"_a = -1,
|
"hop_length"_a = nb::none(),
|
||||||
"win_length"_a = -1,
|
"win_length"_a = nb::none(),
|
||||||
"window"_a = mx::array(),
|
"window"_a = mx::array(),
|
||||||
"center"_a = true,
|
"center"_a = true,
|
||||||
"length"_a = -1,
|
"length"_a = -1,
|
||||||
|
Loading…
Reference in New Issue
Block a user