mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fixed fft.cpp and added tests
This commit is contained in:
@@ -308,3 +308,77 @@ TEST_CASE("test fft grads") {
|
||||
.second;
|
||||
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
|
||||
}
|
||||
|
||||
TEST_CASE("test stft and istft") {
|
||||
int n_fft = 4;
|
||||
int hop_length = 2;
|
||||
int win_length = 4;
|
||||
|
||||
array signal = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, float32);
|
||||
|
||||
array window = array({0.5, 1.0, 1.0, 0.5}, float32);
|
||||
|
||||
SUBCASE("stft basic functionality") {
|
||||
auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
|
||||
|
||||
CHECK_EQ(stft_result.shape(0), 4);
|
||||
CHECK_EQ(stft_result.shape(1), 3);
|
||||
|
||||
CHECK_EQ(stft_result.dtype(), complex64);
|
||||
}
|
||||
|
||||
SUBCASE("istft reconstruction") {
|
||||
auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
|
||||
auto reconstructed_signal =
|
||||
fft::istft(stft_result, hop_length, win_length, window);
|
||||
|
||||
CHECK_EQ(reconstructed_signal.shape(0), signal.shape(0));
|
||||
CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item<bool>());
|
||||
}
|
||||
|
||||
SUBCASE("stft with default parameters") {
|
||||
auto stft_result = fft::stft(signal);
|
||||
|
||||
CHECK_EQ(stft_result.shape(0), 5);
|
||||
CHECK_EQ(stft_result.shape(1), 3);
|
||||
|
||||
CHECK_EQ(stft_result.dtype(), complex64);
|
||||
}
|
||||
|
||||
SUBCASE("istft with length parameter") {
|
||||
auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
|
||||
int length = 6;
|
||||
auto reconstructed_signal =
|
||||
fft::istft(stft_result, hop_length, win_length, window, true, length);
|
||||
|
||||
CHECK_EQ(reconstructed_signal.shape(0), length);
|
||||
|
||||
CHECK(
|
||||
allclose(slice(signal, {0}, {length}), reconstructed_signal, 1e-5, 1e-5)
|
||||
.item<bool>());
|
||||
}
|
||||
|
||||
SUBCASE("stft and istft with normalization") {
|
||||
auto stft_result = fft::stft(
|
||||
signal, n_fft, hop_length, win_length, window, true, "reflect", true);
|
||||
auto reconstructed_signal =
|
||||
fft::istft(stft_result, hop_length, win_length, window, true, -1, true);
|
||||
|
||||
CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item<bool>());
|
||||
}
|
||||
|
||||
SUBCASE("stft with onesided=False") {
|
||||
auto stft_result = fft::stft(
|
||||
signal,
|
||||
n_fft,
|
||||
hop_length,
|
||||
win_length,
|
||||
window,
|
||||
true,
|
||||
"reflect",
|
||||
false,
|
||||
false);
|
||||
|
||||
CHECK_EQ(stft_result.shape(1), n_fft);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user