mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 07:31:26 +08:00
Fixed fft.cpp and added tests
This commit is contained in:
parent
c92a2bc679
commit
0db393acd7
34
mlx/fft.cpp
34
mlx/fft.cpp
@ -225,15 +225,13 @@ array stft(
|
|||||||
|
|
||||||
int n_frames = 1 + (padded_x.shape(0) - n_fft) / hop_length;
|
int n_frames = 1 + (padded_x.shape(0) - n_fft) / hop_length;
|
||||||
|
|
||||||
std::vector<array> frames;
|
Shape strided_shape = {n_frames, n_fft};
|
||||||
for (int i = 0; i < n_frames; ++i) {
|
Strides strided_strides = {
|
||||||
array frame =
|
hop_length * static_cast<long long>(sizeof(float32)),
|
||||||
slice(padded_x, {i * hop_length}, {i * hop_length + n_fft}, s);
|
static_cast<long long>(sizeof(float32))};
|
||||||
frames.push_back(multiply(frame, win, s));
|
array frames = as_strided(padded_x, strided_shape, strided_strides, 0, s);
|
||||||
}
|
|
||||||
|
|
||||||
array stacked_frames = stack(frames, 0, s);
|
|
||||||
|
|
||||||
|
array stacked_frames = multiply(frames, win, s);
|
||||||
array stft_result = mlx::core::fft::rfftn(stacked_frames, {n_fft}, {-1}, s);
|
array stft_result = mlx::core::fft::rfftn(stacked_frames, {n_fft}, {-1}, s);
|
||||||
|
|
||||||
if (normalized) {
|
if (normalized) {
|
||||||
@ -273,23 +271,23 @@ array istft(
|
|||||||
}
|
}
|
||||||
|
|
||||||
array frames = mlx::core::fft::irfftn(stft_matrix, {n_fft}, {-1}, s);
|
array frames = mlx::core::fft::irfftn(stft_matrix, {n_fft}, {-1}, s);
|
||||||
|
|
||||||
frames = multiply(frames, win, s);
|
frames = multiply(frames, win, s);
|
||||||
|
|
||||||
int signal_length = (frames.shape(0) - 1) * hop_length + n_fft;
|
int signal_length = (frames.shape(0) - 1) * hop_length + n_fft;
|
||||||
array signal = zeros({signal_length}, float32, s);
|
array signal = zeros({signal_length}, float32, s);
|
||||||
array window_sum = zeros({signal_length}, float32, s);
|
array window_sum = zeros({signal_length}, float32, s);
|
||||||
|
|
||||||
for (int i = 0; i < frames.shape(0); ++i) {
|
Shape strided_shape = {frames.shape(0), n_fft};
|
||||||
array frame = reshape(slice(frames, {i}, {i + 1}, s), {n_fft}, s);
|
Strides strided_strides = {
|
||||||
array signal_slice =
|
hop_length * static_cast<long long>(sizeof(float32)),
|
||||||
slice(signal, {i * hop_length}, {i * hop_length + n_fft}, s);
|
static_cast<long long>(sizeof(float32))};
|
||||||
array window_slice =
|
array signal_strided =
|
||||||
slice(window_sum, {i * hop_length}, {i * hop_length + n_fft}, s);
|
as_strided(signal, strided_shape, strided_strides, 0, s);
|
||||||
|
array window_sum_strided =
|
||||||
|
as_strided(window_sum, strided_shape, strided_strides, 0, s);
|
||||||
|
|
||||||
signal_slice = add(signal_slice, frame, s);
|
signal_strided = add(signal_strided, frames, s);
|
||||||
window_slice = add(window_slice, win, s);
|
window_sum_strided = add(window_sum_strided, win, s);
|
||||||
}
|
|
||||||
|
|
||||||
signal = divide(signal, window_sum, s);
|
signal = divide(signal, window_sum, s);
|
||||||
|
|
||||||
|
@ -308,3 +308,77 @@ TEST_CASE("test fft grads") {
|
|||||||
.second;
|
.second;
|
||||||
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
|
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Exercises fft::stft and fft::istft on a short ramp signal: output shape and
// dtype, round-trip reconstruction, an explicit output length, normalization,
// and the two-sided (onesided = false) spectrum.
TEST_CASE("test stft and istft") {
  const int n_fft = 4;
  const int hop_length = 2;
  const int win_length = 4;

  // Tolerance used for both the relative and absolute terms of allclose.
  const double tol = 1e-5;

  const array signal({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, float32);
  const array window({0.5, 1.0, 1.0, 0.5}, float32);

  // Several subcases start from the same windowed transform of `signal`.
  const auto windowed_stft = [&]() {
    return fft::stft(signal, n_fft, hop_length, win_length, window);
  };

  SUBCASE("stft basic functionality") {
    const auto spectrum = windowed_stft();

    CHECK_EQ(spectrum.shape(0), 4);
    CHECK_EQ(spectrum.shape(1), 3);
    CHECK_EQ(spectrum.dtype(), complex64);
  }

  SUBCASE("istft reconstruction") {
    const auto spectrum = windowed_stft();
    const auto recovered =
        fft::istft(spectrum, hop_length, win_length, window);

    CHECK_EQ(recovered.shape(0), signal.shape(0));
    CHECK(allclose(signal, recovered, tol, tol).item<bool>());
  }

  SUBCASE("stft with default parameters") {
    const auto spectrum = fft::stft(signal);

    CHECK_EQ(spectrum.shape(0), 5);
    CHECK_EQ(spectrum.shape(1), 3);
    CHECK_EQ(spectrum.dtype(), complex64);
  }

  SUBCASE("istft with length parameter") {
    const auto spectrum = windowed_stft();
    const int length = 6;
    const auto recovered =
        fft::istft(spectrum, hop_length, win_length, window, true, length);

    CHECK_EQ(recovered.shape(0), length);

    // Only the first `length` samples of the input are expected back.
    const auto expected_prefix = slice(signal, {0}, {length});
    CHECK(allclose(expected_prefix, recovered, tol, tol).item<bool>());
  }

  SUBCASE("stft and istft with normalization") {
    const auto spectrum = fft::stft(
        signal, n_fft, hop_length, win_length, window, true, "reflect", true);
    const auto recovered =
        fft::istft(spectrum, hop_length, win_length, window, true, -1, true);

    CHECK(allclose(signal, recovered, tol, tol).item<bool>());
  }

  SUBCASE("stft with onesided=False") {
    const auto spectrum = fft::stft(
        signal,
        n_fft,
        hop_length,
        win_length,
        window,
        true,
        "reflect",
        false,
        false);

    // The two-sided transform keeps all n_fft frequency bins per frame.
    CHECK_EQ(spectrum.shape(1), n_fft);
  }
}
|
Loading…
Reference in New Issue
Block a user