This commit is contained in:
Param Thakkar 2025-06-18 10:49:45 +08:00 committed by GitHub
commit 445478c98b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 338 additions and 3 deletions

View File

@ -10,7 +10,7 @@ namespace mx = mlx::core;
void time_value_and_grad() {
auto x = mx::ones({200, 1000});
mx::eval(x);
auto fn = [](mx::array x) {
auto fn = [](mx::x) {
for (int i = 0; i < 20; ++i) {
x = mx::log(mx::exp(x));
}

View File

@ -1,4 +1,6 @@
// Copyright © 2023 Apple Inc.
#include <cmath>
#include <numeric>
#include <set>
@ -189,6 +191,124 @@ array irfftn(const array& a, StreamOrDevice s /* = {} */) {
return fft_impl(a, true, true, s);
}
array stft(
const array& x,
int n_fft = 2048,
int hop_length = -1,
int win_length = -1,
const array& window,
bool center = true,
const std::string& pad_mode = "reflect",
bool normalized = false,
bool onesided = true,
StreamOrDevice s /* = {} */) {
if (hop_length == -1)
hop_length = n_fft / 4;
if (win_length == -1)
win_length = n_fft;
array win = (window.size() == 0) ? ones({win_length}, float32, s) : window;
if (win_length < n_fft) {
int pad_left = (n_fft - win_length) / 2;
int pad_right = n_fft - win_length - pad_left;
win = mlx::core::pad(
win, {{pad_left, pad_right}}, array(0, float32), "constant", s);
}
array padded_x = x;
if (center) {
int pad_width = n_fft / 2;
padded_x = mlx::core::pad(
padded_x, {{pad_width, pad_width}}, array(0, x.dtype()), pad_mode, s);
}
int n_frames = 1 + (padded_x.shape(0) - n_fft) / hop_length;
Shape strided_shape = {n_frames, n_fft};
Strides strided_strides = {
hop_length * static_cast<long long>(sizeof(float32)),
static_cast<long long>(sizeof(float32))};
array frames = as_strided(padded_x, strided_shape, strided_strides, 0, s);
array stacked_frames = multiply(frames, win, s);
array stft_result = mlx::core::fft::rfftn(stacked_frames, {n_fft}, {-1}, s);
if (normalized) {
array n_fft_array = full({1}, static_cast<float>(n_fft), float32, s);
stft_result = divide(stft_result, sqrt(n_fft_array, s), s);
}
if (onesided) {
stft_result = slice(stft_result, {}, {n_fft / 2 + 1}, s);
}
return stft_result;
}
array istft(
const array& stft_matrix,
int hop_length = -1,
int win_length = -1,
const array& window,
bool center = true,
int length = -1,
bool normalized = false,
StreamOrDevice s /* = {} */) {
int n_fft = (stft_matrix.shape(-1) - 1) * 2;
if (hop_length == -1)
hop_length = n_fft / 4;
if (win_length == -1)
win_length = n_fft;
array win = (window.size() == 0) ? ones({win_length}, float32, s) : window;
if (win_length < n_fft) {
int pad_left = (n_fft - win_length) / 2;
int pad_right = n_fft - win_length - pad_left;
win = mlx::core::pad(
win, {{pad_left, pad_right}}, array(0, float32), "constant", s);
}
array frames = mlx::core::fft::irfftn(stft_matrix, {n_fft}, {-1}, s);
frames = multiply(frames, win, s);
int signal_length = (frames.shape(0) - 1) * hop_length + n_fft;
array signal = zeros({signal_length}, float32, s);
array window_sum = zeros({signal_length}, float32, s);
Shape strided_shape = {frames.shape(0), n_fft};
Strides strided_strides = {
hop_length * static_cast<long long>(sizeof(float32)),
static_cast<long long>(sizeof(float32))};
array signal_strided =
as_strided(signal, strided_shape, strided_strides, 0, s);
array window_sum_strided =
as_strided(window_sum, strided_shape, strided_strides, 0, s);
signal_strided = add(signal_strided, frames, s);
window_sum_strided = add(window_sum_strided, win, s);
signal = divide(signal, window_sum, s);
if (center) {
int pad_width = n_fft / 2;
signal = slice(signal, {pad_width}, {signal.shape(0) - pad_width}, s);
}
if (length > 0) {
if (signal.shape(0) > length) {
signal = slice(signal, {0}, {length}, s);
} else if (signal.shape(0) < length) {
int pad_length = length - signal.shape(0);
signal = mlx::core::pad(
signal, {{0, pad_length}}, array(0, signal.dtype()), "constant", s);
}
}
return signal;
}
array fftshift(
const array& a,
const std::vector<int>& axes,
@ -258,5 +378,4 @@ array ifftshift(const array& a, StreamOrDevice s /* = {} */) {
std::iota(axes.begin(), axes.end(), 0);
return ifftshift(a, axes, s);
}
} // namespace mlx::core::fft
} // namespace mlx::core::fft

View File

@ -6,8 +6,11 @@
#include "array.h"
#include "device.h"
#include "mlx/mlx.h"
#include "utils.h"
namespace mx = mlx::core;
namespace mlx::core::fft {
/** Compute the n-dimensional Fourier Transform. */
@ -164,4 +167,48 @@ array ifftshift(
const std::vector<int>& axes,
StreamOrDevice s = {});
inline array stft(
const array& x,
int n_fft = 2048,
int hop_length = -1,
int win_length = -1,
const array& window = mx::array({}),
bool center = true,
const std::string& pad_mode = "reflect",
bool normalized = false,
bool onesided = true,
StreamOrDevice s = {}) {
return mlx::core::fft::stft(
x,
n_fft,
hop_length,
win_length,
window,
center,
pad_mode,
normalized,
onesided,
s);
}
inline array istft(
const array& stft_matrix,
int hop_length = -1,
int win_length = -1,
const array& window = mx::array({}),
bool center = true,
int length = -1,
bool normalized = false,
StreamOrDevice s = {}) {
return mlx::core::fft::istft(
stft_matrix,
hop_length,
win_length,
window,
center,
length,
normalized,
s);
}
} // namespace mlx::core::fft

View File

@ -459,6 +459,101 @@ void init_fft(nb::module_& parent_module) {
Returns:
array: The real array containing the inverse of :func:`rfftn`.
)pbdoc");
m.def(
"stft",
[](const mx::array& x,
int n_fft = 2048,
std::optional<int> hop_length = std::nullopt,
std::optional<int> win_length = std::nullopt,
const mx::array& window = mx::array(),
bool center = true,
const std::string& pad_mode = "reflect",
bool normalized = false,
bool onesided = true,
mx::StreamOrDevice s = {}) {
return mx::stft::stft(
x,
n_fft,
hop_length,
win_length,
window,
center,
pad_mode,
normalized,
onesided,
s);
},
"x"_a,
"n_fft"_a = 2048,
"hop_length"_a = nb::none(),
"win_length"_a = nb::none(),
"window"_a = mx::array(),
"center"_a = true,
"pad_mode"_a = "reflect",
"normalized"_a = false,
"onesided"_a = true,
"stream"_a = nb::none(),
R"pbdoc(
Short-time Fourier Transform (STFT).
Args:
x (array): Input signal.
n_fft (int, optional): Number of FFT points. Default is 2048.
hop_length (int, optional): Number of samples between successive frames. Default is `n_fft // 4`.
win_length (int, optional): Window size. Default is `n_fft`.
window (array, optional): Window function. Default is a rectangular window.
center (bool, optional): Whether to pad the signal to center the frames. Default is True.
pad_mode (str, optional): Padding mode. Default is "reflect".
normalized (bool, optional): Whether to normalize the STFT. Default is False.
onesided (bool, optional): Whether to return a one-sided STFT. Default is True.
Returns:
array: The STFT of the input signal.
)pbdoc");
m.def(
"istft",
[](const mx::array& stft_matrix,
std::optional<int> hop_length = std::nullopt,
std::optional<int> win_length = std::nullopt,
const mx::array& window = mx::array(),
bool center = true,
int length = -1,
bool normalized = false,
mx::StreamOrDevice s = {}) {
return mx::stft::istft(
stft_matrix,
hop_length,
win_length,
window,
center,
length,
normalized,
s);
},
"stft_matrix"_a,
"hop_length"_a = nb::none(),
"win_length"_a = nb::none(),
"window"_a = mx::array(),
"center"_a = true,
"length"_a = -1,
"normalized"_a = false,
"stream"_a = nb::none(),
R"pbdoc(
Inverse Short-time Fourier Transform (ISTFT).
Args:
stft_matrix (array): Input STFT matrix.
hop_length (int, optional): Number of samples between successive frames. Default is `n_fft // 4`.
win_length (int, optional): Window size. Default is `n_fft`.
window (array, optional): Window function. Default is a rectangular window.
center (bool, optional): Whether the signal was padded to center the frames. Default is True.
length (int, optional): Length of the output signal. Default is inferred from the STFT matrix.
normalized (bool, optional): Whether the STFT was normalized. Default is False.
Returns:
array: The reconstructed signal.
m.def(
"fftshift",
[](const mx::array& a,

View File

@ -307,6 +307,80 @@ TEST_CASE("test fft grads") {
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
}
TEST_CASE("test stft and istft") {
int n_fft = 4;
int hop_length = 2;
int win_length = 4;
array signal = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, float32);
array window = array({0.5, 1.0, 1.0, 0.5}, float32);
SUBCASE("stft basic functionality") {
auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
CHECK_EQ(stft_result.shape(0), 4);
CHECK_EQ(stft_result.shape(1), 3);
CHECK_EQ(stft_result.dtype(), complex64);
}
SUBCASE("istft reconstruction") {
auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
auto reconstructed_signal =
fft::istft(stft_result, hop_length, win_length, window);
CHECK_EQ(reconstructed_signal.shape(0), signal.shape(0));
CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item<bool>());
}
SUBCASE("stft with default parameters") {
auto stft_result = fft::stft(signal);
CHECK_EQ(stft_result.shape(0), 5);
CHECK_EQ(stft_result.shape(1), 3);
CHECK_EQ(stft_result.dtype(), complex64);
}
SUBCASE("istft with length parameter") {
auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
int length = 6;
auto reconstructed_signal =
fft::istft(stft_result, hop_length, win_length, window, true, length);
CHECK_EQ(reconstructed_signal.shape(0), length);
CHECK(
allclose(slice(signal, {0}, {length}), reconstructed_signal, 1e-5, 1e-5)
.item<bool>());
}
SUBCASE("stft and istft with normalization") {
auto stft_result = fft::stft(
signal, n_fft, hop_length, win_length, window, true, "reflect", true);
auto reconstructed_signal =
fft::istft(stft_result, hop_length, win_length, window, true, -1, true);
CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item<bool>());
}
SUBCASE("stft with onesided=False") {
auto stft_result = fft::stft(
signal,
n_fft,
hop_length,
win_length,
window,
true,
"reflect",
false,
false);
CHECK_EQ(stft_result.shape(1), n_fft);
}
}
TEST_CASE("test fftshift and ifftshift") {
// Test 1D array with even length
auto x = arange(8);