mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Merge eeaf1fa463
into cad5c0241c
This commit is contained in:
commit
445478c98b
@ -10,7 +10,7 @@ namespace mx = mlx::core;
|
||||
void time_value_and_grad() {
|
||||
auto x = mx::ones({200, 1000});
|
||||
mx::eval(x);
|
||||
auto fn = [](mx::array x) {
|
||||
auto fn = [](mx::x) {
|
||||
for (int i = 0; i < 20; ++i) {
|
||||
x = mx::log(mx::exp(x));
|
||||
}
|
||||
|
123
mlx/fft.cpp
123
mlx/fft.cpp
@ -1,4 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cmath>
#include <numeric>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
|
||||
|
||||
@ -189,6 +191,124 @@ array irfftn(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return fft_impl(a, true, true, s);
|
||||
}
|
||||
|
||||
namespace {

// Pads a 1-D signal with `pad_width` samples on each side.
//
// "reflect" mirrors the signal around its first and last sample (the edge
// sample itself is not repeated) and is built here from gathers so it does
// not depend on the modes `mlx::core::pad` implements. Every other mode is
// forwarded to `mlx::core::pad` unchanged.
array stft_pad_signal(
    const array& x,
    int pad_width,
    const std::string& pad_mode,
    StreamOrDevice s) {
  if (pad_width == 0) {
    return x;
  }
  if (pad_mode == "reflect") {
    const int n = x.shape(0);
    if (pad_width >= n) {
      throw std::invalid_argument(
          "[stft] Reflect padding requires the signal to be longer than n_fft / 2.");
    }
    // x[pad_width], x[pad_width - 1], ..., x[1]
    array left = take(x, arange(pad_width, 0, -1, s), 0, s);
    // x[n - 2], x[n - 3], ..., x[n - 1 - pad_width]
    array right = take(x, arange(n - 2, n - 2 - pad_width, -1, s), 0, s);
    return concatenate({left, x, right}, 0, s);
  }
  return mlx::core::pad(
      x,
      std::vector<std::pair<int, int>>{{pad_width, pad_width}},
      array(0, x.dtype()),
      pad_mode,
      s);
}

} // namespace

// Short-time Fourier transform of a 1-D signal.
//
// x:          1-D input signal.
// n_fft:      FFT size; each frame holds n_fft samples.
// hop_length: samples between frame starts; -1 selects n_fft / 4.
// win_length: window length; -1 selects n_fft. Shorter windows are
//             zero-padded symmetrically to n_fft.
// window:     1-D window of win_length samples; an empty array selects a
//             rectangular (all ones) window.
// center:     when true the signal is padded by n_fft / 2 on both sides using
//             pad_mode so that frames are centered on their time stamp.
// normalized: when true the result is divided by sqrt(n_fft).
// onesided:   when true only the n_fft / 2 + 1 non-negative frequency bins are
//             returned, otherwise all n_fft bins.
//
// Returns an array of shape (n_frames, n_bins) with
// n_frames = 1 + (padded_length - n_fft) / hop_length.
//
// Throws std::invalid_argument on non 1-D input, non-positive sizes, a window
// whose length does not match win_length, or a signal shorter than n_fft.
//
// The default values live on the declaration in fft.h; they are repeated here
// as comments only, since a default argument may not be specified twice.
array stft(
    const array& x,
    int n_fft /* = 2048 */,
    int hop_length /* = -1 */,
    int win_length /* = -1 */,
    const array& window /* = array({}) */,
    bool center /* = true */,
    const std::string& pad_mode /* = "reflect" */,
    bool normalized /* = false */,
    bool onesided /* = true */,
    StreamOrDevice s /* = {} */) {
  if (x.ndim() != 1) {
    throw std::invalid_argument("[stft] Input signal must be one-dimensional.");
  }
  if (n_fft <= 0) {
    throw std::invalid_argument("[stft] n_fft must be positive.");
  }
  if (hop_length == -1) {
    hop_length = n_fft / 4;
  }
  if (win_length == -1) {
    win_length = n_fft;
  }
  // hop_length is used as a divisor below, so zero must be rejected up front.
  if (hop_length <= 0) {
    throw std::invalid_argument("[stft] hop_length must be positive.");
  }
  if (win_length <= 0 || win_length > n_fft) {
    throw std::invalid_argument(
        "[stft] win_length must be in the range [1, n_fft].");
  }

  array win = (window.size() == 0) ? ones({win_length}, float32, s) : window;
  if (win.ndim() != 1 || win.shape(0) != win_length) {
    throw std::invalid_argument(
        "[stft] window must be one-dimensional with win_length elements.");
  }
  if (win_length < n_fft) {
    const int pad_left = (n_fft - win_length) / 2;
    const int pad_right = n_fft - win_length - pad_left;
    win = mlx::core::pad(
        win,
        std::vector<std::pair<int, int>>{{pad_left, pad_right}},
        array(0, win.dtype()),
        "constant",
        s);
  }

  array padded_x = center ? stft_pad_signal(x, n_fft / 2, pad_mode, s) : x;
  if (padded_x.shape(0) < n_fft) {
    throw std::invalid_argument(
        "[stft] Signal is shorter than n_fft; no complete frame is available.");
  }
  const int n_frames = 1 + (padded_x.shape(0) - n_fft) / hop_length;

  // as_strided takes strides in elements (not bytes): consecutive frames start
  // hop_length samples apart and samples inside a frame are contiguous.
  Shape frame_shape = {n_frames, n_fft};
  Strides frame_strides = {hop_length, 1};
  array frames = as_strided(padded_x, frame_shape, frame_strides, 0, s);
  array windowed = multiply(frames, win, s);

  // rfftn already yields the one-sided spectrum (n_fft / 2 + 1 bins); the
  // two-sided spectrum needs the full complex transform.
  array spectrum = onesided ? rfftn(windowed, Shape{n_fft}, {-1}, s)
                            : fftn(windowed, Shape{n_fft}, {-1}, s);

  if (normalized) {
    spectrum =
        divide(spectrum, array(std::sqrt(static_cast<float>(n_fft))), s);
  }
  return spectrum;
}
|
||||
|
||||
// Inverse short-time Fourier transform using windowed overlap-add.
//
// stft_matrix: one-sided STFT of shape (n_frames, n_fft / 2 + 1); n_fft is
//              recovered as (n_bins - 1) * 2.
// hop_length:  samples between frame starts; -1 selects n_fft / 4.
// win_length:  window length; -1 selects n_fft. Shorter windows are
//              zero-padded symmetrically to n_fft.
// window:      1-D synthesis window of win_length samples; an empty array
//              selects a rectangular (all ones) window.
// center:      when true, n_fft / 2 samples are trimmed from both ends of the
//              reconstructed signal.
// length:      when positive, the output is truncated or zero-padded to
//              exactly this many samples.
// normalized:  when true the frames are rescaled by sqrt(n_fft) to undo the
//              1 / sqrt(n_fft) scaling of a normalized forward transform.
//
// Throws std::invalid_argument on a non 2-D input, non-positive sizes or a
// window whose length does not match win_length.
//
// The default values live on the declaration in fft.h; they are repeated here
// as comments only, since a default argument may not be specified twice.
array istft(
    const array& stft_matrix,
    int hop_length /* = -1 */,
    int win_length /* = -1 */,
    const array& window /* = array({}) */,
    bool center /* = true */,
    int length /* = -1 */,
    bool normalized /* = false */,
    StreamOrDevice s /* = {} */) {
  if (stft_matrix.ndim() != 2) {
    throw std::invalid_argument(
        "[istft] Input must be a two-dimensional (frames, bins) array.");
  }
  const int n_frames = stft_matrix.shape(0);
  const int n_fft = (stft_matrix.shape(1) - 1) * 2;
  if (n_frames < 1 || n_fft <= 0) {
    throw std::invalid_argument(
        "[istft] Input must contain at least one frame and two frequency bins.");
  }
  if (hop_length == -1) {
    hop_length = n_fft / 4;
  }
  if (win_length == -1) {
    win_length = n_fft;
  }
  if (hop_length <= 0) {
    throw std::invalid_argument("[istft] hop_length must be positive.");
  }
  if (win_length <= 0 || win_length > n_fft) {
    throw std::invalid_argument(
        "[istft] win_length must be in the range [1, n_fft].");
  }

  array win = (window.size() == 0) ? ones({win_length}, float32, s) : window;
  if (win.ndim() != 1 || win.shape(0) != win_length) {
    throw std::invalid_argument(
        "[istft] window must be one-dimensional with win_length elements.");
  }
  if (win_length < n_fft) {
    const int pad_left = (n_fft - win_length) / 2;
    const int pad_right = n_fft - win_length - pad_left;
    win = mlx::core::pad(
        win,
        std::vector<std::pair<int, int>>{{pad_left, pad_right}},
        array(0, win.dtype()),
        "constant",
        s);
  }

  // Back to the time domain, one row of n_fft samples per frame.
  array frames = irfftn(stft_matrix, Shape{n_fft}, {-1}, s);
  if (normalized) {
    frames =
        multiply(frames, array(std::sqrt(static_cast<float>(n_fft))), s);
  }
  frames = multiply(frames, win, s);

  const int signal_length = (n_frames - 1) * hop_length + n_fft;

  // Output position of every (frame, offset) pair: frame * hop_length + offset.
  array frame_starts = multiply(arange(n_frames, s), array(hop_length), s);
  array positions =
      add(expand_dims(frame_starts, 1, s), arange(n_fft, s), s);
  positions = reshape(positions, {-1}, s);

  // Arrays are immutable, so the overlap-add is expressed as a scatter-add:
  // samples of overlapping frames that land on the same position accumulate.
  array signal = scatter_add(
      zeros({signal_length}, float32, s),
      positions,
      reshape(frames, {-1, 1}, s),
      0,
      s);

  // The window is applied once on analysis and once on synthesis, so the
  // overlap-added signal is scaled by the sum of squared windows.
  array squared_window =
      broadcast_to(square(win, s), Shape{n_frames, n_fft}, s);
  array envelope = scatter_add(
      zeros({signal_length}, float32, s),
      positions,
      reshape(squared_window, {-1, 1}, s),
      0,
      s);

  // Leave samples untouched where the envelope is (numerically) zero instead
  // of producing inf / NaN.
  array usable = greater(envelope, array(1e-11f), s);
  signal = where(usable, divide(signal, envelope, s), signal, s);

  if (center) {
    const int pad_width = n_fft / 2;
    signal = slice(signal, {pad_width}, {signal_length - pad_width}, s);
  }

  if (length > 0) {
    const int current = signal.shape(0);
    if (current > length) {
      signal = slice(signal, {0}, {length}, s);
    } else if (current < length) {
      signal = mlx::core::pad(
          signal,
          std::vector<std::pair<int, int>>{{0, length - current}},
          array(0, signal.dtype()),
          "constant",
          s);
    }
  }

  return signal;
}
|
||||
|
||||
array fftshift(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
@ -258,5 +378,4 @@ array ifftshift(const array& a, StreamOrDevice s /* = {} */) {
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
return ifftshift(a, axes, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fft
|
||||
} // namespace mlx::core::fft
|
47
mlx/fft.h
47
mlx/fft.h
@ -6,8 +6,11 @@
|
||||
|
||||
#include "array.h"
|
||||
#include "device.h"
|
||||
#include "mlx/mlx.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
namespace mlx::core::fft {
|
||||
|
||||
/** Compute the n-dimensional Fourier Transform. */
|
||||
@ -164,4 +167,48 @@ array ifftshift(
|
||||
const std::vector<int>& axes,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/**
 * Compute the short-time Fourier transform of a signal.
 *
 * This is only a declaration: the implementation lives in fft.cpp. An inline
 * wrapper with this exact signature would resolve to itself and recurse
 * forever, and would clash with the out-of-line definition.
 *
 * @param x Input signal.
 * @param n_fft FFT size.
 * @param hop_length Samples between successive frames; -1 selects n_fft / 4.
 * @param win_length Window length; -1 selects n_fft.
 * @param window Window function; an empty array selects a rectangular window.
 * @param center Pad the signal by n_fft / 2 on both sides before framing.
 * @param pad_mode Padding mode used when `center` is true.
 * @param normalized Scale the result by 1 / sqrt(n_fft).
 * @param onesided Return only the one-sided spectrum.
 * @param s Stream or device to run on.
 */
array stft(
    const array& x,
    int n_fft = 2048,
    int hop_length = -1,
    int win_length = -1,
    const array& window = array({}),
    bool center = true,
    const std::string& pad_mode = "reflect",
    bool normalized = false,
    bool onesided = true,
    StreamOrDevice s = {});
|
||||
|
||||
/**
 * Compute the inverse short-time Fourier transform.
 *
 * This is only a declaration: the implementation lives in fft.cpp. An inline
 * wrapper with this exact signature would resolve to itself and recurse
 * forever, and would clash with the out-of-line definition.
 *
 * @param stft_matrix Input STFT; n_fft is derived from its last dimension as
 *        (bins - 1) * 2.
 * @param hop_length Samples between successive frames; -1 selects n_fft / 4.
 * @param win_length Window length; -1 selects n_fft.
 * @param window Window function; an empty array selects a rectangular window.
 * @param center Whether the forward transform padded the signal by n_fft / 2
 *        on both sides; if so that padding is trimmed from the output.
 * @param length When positive, truncate or zero-pad the output to this length.
 * @param normalized Whether the forward transform was normalized.
 * @param s Stream or device to run on.
 */
array istft(
    const array& stft_matrix,
    int hop_length = -1,
    int win_length = -1,
    const array& window = array({}),
    bool center = true,
    int length = -1,
    bool normalized = false,
    StreamOrDevice s = {});
|
||||
|
||||
} // namespace mlx::core::fft
|
||||
|
@ -459,6 +459,101 @@ void init_fft(nb::module_& parent_module) {
|
||||
Returns:
|
||||
array: The real array containing the inverse of :func:`rfftn`.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"stft",
|
||||
[](const mx::array& x,
|
||||
int n_fft = 2048,
|
||||
std::optional<int> hop_length = std::nullopt,
|
||||
std::optional<int> win_length = std::nullopt,
|
||||
const mx::array& window = mx::array(),
|
||||
bool center = true,
|
||||
const std::string& pad_mode = "reflect",
|
||||
bool normalized = false,
|
||||
bool onesided = true,
|
||||
mx::StreamOrDevice s = {}) {
|
||||
return mx::stft::stft(
|
||||
x,
|
||||
n_fft,
|
||||
hop_length,
|
||||
win_length,
|
||||
window,
|
||||
center,
|
||||
pad_mode,
|
||||
normalized,
|
||||
onesided,
|
||||
s);
|
||||
},
|
||||
"x"_a,
|
||||
"n_fft"_a = 2048,
|
||||
"hop_length"_a = nb::none(),
|
||||
"win_length"_a = nb::none(),
|
||||
"window"_a = mx::array(),
|
||||
"center"_a = true,
|
||||
"pad_mode"_a = "reflect",
|
||||
"normalized"_a = false,
|
||||
"onesided"_a = true,
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Short-time Fourier Transform (STFT).
|
||||
|
||||
Args:
|
||||
x (array): Input signal.
|
||||
n_fft (int, optional): Number of FFT points. Default is 2048.
|
||||
hop_length (int, optional): Number of samples between successive frames. Default is `n_fft // 4`.
|
||||
win_length (int, optional): Window size. Default is `n_fft`.
|
||||
window (array, optional): Window function. Default is a rectangular window.
|
||||
center (bool, optional): Whether to pad the signal to center the frames. Default is True.
|
||||
pad_mode (str, optional): Padding mode. Default is "reflect".
|
||||
normalized (bool, optional): Whether to normalize the STFT. Default is False.
|
||||
onesided (bool, optional): Whether to return a one-sided STFT. Default is True.
|
||||
|
||||
Returns:
|
||||
array: The STFT of the input signal.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"istft",
|
||||
[](const mx::array& stft_matrix,
|
||||
std::optional<int> hop_length = std::nullopt,
|
||||
std::optional<int> win_length = std::nullopt,
|
||||
const mx::array& window = mx::array(),
|
||||
bool center = true,
|
||||
int length = -1,
|
||||
bool normalized = false,
|
||||
mx::StreamOrDevice s = {}) {
|
||||
return mx::stft::istft(
|
||||
stft_matrix,
|
||||
hop_length,
|
||||
win_length,
|
||||
window,
|
||||
center,
|
||||
length,
|
||||
normalized,
|
||||
s);
|
||||
},
|
||||
"stft_matrix"_a,
|
||||
"hop_length"_a = nb::none(),
|
||||
"win_length"_a = nb::none(),
|
||||
"window"_a = mx::array(),
|
||||
"center"_a = true,
|
||||
"length"_a = -1,
|
||||
"normalized"_a = false,
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Inverse Short-time Fourier Transform (ISTFT).
|
||||
|
||||
Args:
|
||||
stft_matrix (array): Input STFT matrix.
|
||||
hop_length (int, optional): Number of samples between successive frames. Default is `n_fft // 4`.
|
||||
win_length (int, optional): Window size. Default is `n_fft`.
|
||||
window (array, optional): Window function. Default is a rectangular window.
|
||||
center (bool, optional): Whether the signal was padded to center the frames. Default is True.
|
||||
length (int, optional): Length of the output signal. Default is inferred from the STFT matrix.
|
||||
normalized (bool, optional): Whether the STFT was normalized. Default is False.
|
||||
|
||||
Returns:
|
||||
array: The reconstructed signal.
|
||||
m.def(
|
||||
"fftshift",
|
||||
[](const mx::array& a,
|
||||
|
@ -307,6 +307,80 @@ TEST_CASE("test fft grads") {
|
||||
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
|
||||
}
|
||||
|
||||
// Exercises fft::stft / fft::istft on a small 8-sample ramp.
// With n_fft = 4, hop_length = 2 and center = true the signal is padded by
// n_fft / 2 = 2 samples per side, giving 1 + (8 + 4 - 4) / 2 = 5 frames of
// n_fft / 2 + 1 = 3 one-sided bins.
TEST_CASE("test stft and istft") {
  const int n_fft = 4;
  const int hop_length = 2;
  const int win_length = 4;

  array signal = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, float32);
  array window = array({0.5, 1.0, 1.0, 0.5}, float32);

  SUBCASE("stft basic functionality") {
    auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);

    CHECK_EQ(stft_result.shape(0), 5);
    CHECK_EQ(stft_result.shape(1), 3);
    CHECK_EQ(stft_result.dtype(), complex64);
  }

  SUBCASE("istft reconstruction") {
    auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
    auto reconstructed_signal =
        fft::istft(stft_result, hop_length, win_length, window);

    CHECK_EQ(reconstructed_signal.shape(0), signal.shape(0));
    CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item<bool>());
  }

  SUBCASE("stft with default parameters") {
    // The defaults (n_fft = 2048, hop_length = 512, centered) need a signal
    // longer than n_fft / 2, so the 8-sample ramp cannot be used here.
    // 4096 samples -> 1 + (4096 + 2048 - 2048) / 512 = 9 frames, 1025 bins.
    auto long_signal = ones({4096});
    auto stft_result = fft::stft(long_signal);

    CHECK_EQ(stft_result.shape(0), 9);
    CHECK_EQ(stft_result.shape(1), 1025);
    CHECK_EQ(stft_result.dtype(), complex64);
  }

  SUBCASE("istft with length parameter") {
    auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
    const int length = 6;
    auto reconstructed_signal =
        fft::istft(stft_result, hop_length, win_length, window, true, length);

    CHECK_EQ(reconstructed_signal.shape(0), length);
    CHECK(
        allclose(slice(signal, {0}, {length}), reconstructed_signal, 1e-5, 1e-5)
            .item<bool>());
  }

  SUBCASE("stft and istft with normalization") {
    auto stft_result = fft::stft(
        signal, n_fft, hop_length, win_length, window, true, "reflect", true);
    auto reconstructed_signal =
        fft::istft(stft_result, hop_length, win_length, window, true, -1, true);

    CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item<bool>());
  }

  SUBCASE("stft with onesided=False") {
    auto stft_result = fft::stft(
        signal,
        n_fft,
        hop_length,
        win_length,
        window,
        true,
        "reflect",
        false,
        false);

    // The two-sided spectrum keeps all n_fft bins.
    CHECK_EQ(stft_result.shape(1), n_fft);
  }
}
|
||||
|
||||
TEST_CASE("test fftshift and ifftshift") {
|
||||
// Test 1D array with even length
|
||||
auto x = arange(8);
|
||||
|
Loading…
Reference in New Issue
Block a user