This commit is contained in:
Param Thakkar
2025-07-26 19:00:49 +09:00
committed by GitHub
5 changed files with 338 additions and 3 deletions

View File

@@ -307,6 +307,80 @@ TEST_CASE("test fft grads") {
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
}
TEST_CASE("test stft and istft") {
int n_fft = 4;
int hop_length = 2;
int win_length = 4;
array signal = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, float32);
array window = array({0.5, 1.0, 1.0, 0.5}, float32);
SUBCASE("stft basic functionality") {
auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
CHECK_EQ(stft_result.shape(0), 4);
CHECK_EQ(stft_result.shape(1), 3);
CHECK_EQ(stft_result.dtype(), complex64);
}
SUBCASE("istft reconstruction") {
auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
auto reconstructed_signal =
fft::istft(stft_result, hop_length, win_length, window);
CHECK_EQ(reconstructed_signal.shape(0), signal.shape(0));
CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item<bool>());
}
SUBCASE("stft with default parameters") {
auto stft_result = fft::stft(signal);
CHECK_EQ(stft_result.shape(0), 5);
CHECK_EQ(stft_result.shape(1), 3);
CHECK_EQ(stft_result.dtype(), complex64);
}
SUBCASE("istft with length parameter") {
auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
int length = 6;
auto reconstructed_signal =
fft::istft(stft_result, hop_length, win_length, window, true, length);
CHECK_EQ(reconstructed_signal.shape(0), length);
CHECK(
allclose(slice(signal, {0}, {length}), reconstructed_signal, 1e-5, 1e-5)
.item<bool>());
}
SUBCASE("stft and istft with normalization") {
auto stft_result = fft::stft(
signal, n_fft, hop_length, win_length, window, true, "reflect", true);
auto reconstructed_signal =
fft::istft(stft_result, hop_length, win_length, window, true, -1, true);
CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item<bool>());
}
SUBCASE("stft with onesided=False") {
auto stft_result = fft::stft(
signal,
n_fft,
hop_length,
win_length,
window,
true,
"reflect",
false,
false);
CHECK_EQ(stft_result.shape(1), n_fft);
}
}
TEST_CASE("test fftshift and ifftshift") {
// Test 1D array with even length
auto x = arange(8);