diff --git a/mlx/fft.cpp b/mlx/fft.cpp index a8af2eea8..68b032727 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -225,15 +225,13 @@ array stft( int n_frames = 1 + (padded_x.shape(0) - n_fft) / hop_length; - std::vector frames; - for (int i = 0; i < n_frames; ++i) { - array frame = - slice(padded_x, {i * hop_length}, {i * hop_length + n_fft}, s); - frames.push_back(multiply(frame, win, s)); - } - - array stacked_frames = stack(frames, 0, s); + Shape strided_shape = {n_frames, n_fft}; + Strides strided_strides = { + hop_length * static_cast(sizeof(float32)), + static_cast(sizeof(float32))}; + array frames = as_strided(padded_x, strided_shape, strided_strides, 0, s); + array stacked_frames = multiply(frames, win, s); array stft_result = mlx::core::fft::rfftn(stacked_frames, {n_fft}, {-1}, s); if (normalized) { @@ -273,23 +271,23 @@ array istft( } array frames = mlx::core::fft::irfftn(stft_matrix, {n_fft}, {-1}, s); - frames = multiply(frames, win, s); int signal_length = (frames.shape(0) - 1) * hop_length + n_fft; array signal = zeros({signal_length}, float32, s); array window_sum = zeros({signal_length}, float32, s); - for (int i = 0; i < frames.shape(0); ++i) { - array frame = reshape(slice(frames, {i}, {i + 1}, s), {n_fft}, s); - array signal_slice = - slice(signal, {i * hop_length}, {i * hop_length + n_fft}, s); - array window_slice = - slice(window_sum, {i * hop_length}, {i * hop_length + n_fft}, s); + Shape strided_shape = {frames.shape(0), n_fft}; + Strides strided_strides = { + hop_length * static_cast(sizeof(float32)), + static_cast(sizeof(float32))}; + array signal_strided = + as_strided(signal, strided_shape, strided_strides, 0, s); + array window_sum_strided = + as_strided(window_sum, strided_shape, strided_strides, 0, s); - signal_slice = add(signal_slice, frame, s); - window_slice = add(window_slice, win, s); - } + signal_strided = add(signal_strided, frames, s); + window_sum_strided = add(window_sum_strided, win, s); signal = divide(signal, window_sum, s); diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp index c04dda1d5..b0d8c8e52 100644 --- a/tests/fft_tests.cpp +++ b/tests/fft_tests.cpp @@ -308,3 +308,77 @@ TEST_CASE("test fft grads") { .second; CHECK_EQ(vjp_out.shape(), Shape{5, 5}); } + +TEST_CASE("test stft and istft") { + int n_fft = 4; + int hop_length = 2; + int win_length = 4; + + array signal = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, float32); + + array window = array({0.5, 1.0, 1.0, 0.5}, float32); + + SUBCASE("stft basic functionality") { + auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window); + + CHECK_EQ(stft_result.shape(0), 4); + CHECK_EQ(stft_result.shape(1), 3); + + CHECK_EQ(stft_result.dtype(), complex64); + } + + SUBCASE("istft reconstruction") { + auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window); + auto reconstructed_signal = + fft::istft(stft_result, hop_length, win_length, window); + + CHECK_EQ(reconstructed_signal.shape(0), signal.shape(0)); + CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item()); + } + + SUBCASE("stft with default parameters") { + auto stft_result = fft::stft(signal); + + CHECK_EQ(stft_result.shape(0), 5); + CHECK_EQ(stft_result.shape(1), 3); + + CHECK_EQ(stft_result.dtype(), complex64); + } + + SUBCASE("istft with length parameter") { + auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window); + int length = 6; + auto reconstructed_signal = + fft::istft(stft_result, hop_length, win_length, window, true, length); + + CHECK_EQ(reconstructed_signal.shape(0), length); + + CHECK( + allclose(slice(signal, {0}, {length}), reconstructed_signal, 1e-5, 1e-5) + .item()); + } + + SUBCASE("stft and istft with normalization") { + auto stft_result = fft::stft( + signal, n_fft, hop_length, win_length, window, true, "reflect", true); + auto reconstructed_signal = + fft::istft(stft_result, hop_length, win_length, window, true, -1, true); + + CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item()); + } + + SUBCASE("stft with onesided=False") { + auto stft_result = fft::stft( + signal, + n_fft, + hop_length, + win_length, + window, + true, + "reflect", + false, + false); + + CHECK_EQ(stft_result.shape(1), n_fft); + } +} \ No newline at end of file