Fix the fft tests in C++

This commit is contained in:
Angelos Katharopoulos 2025-05-12 22:52:35 -07:00
parent 194f1adbd8
commit f93cda7a1c

View File

@ -243,7 +243,7 @@ TEST_CASE("test fft grads") {
auto fft_fn = [](array x) { return fft::fft(x); };
auto cotangent = astype(arange(10), complex64);
auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second;
CHECK(array_equal(fft::fft(cotangent), vjp_out).item<bool>());
CHECK(array_equal(fft::ifft(cotangent) * 10, vjp_out).item<bool>());
auto tangent = astype(arange(10), complex64);
auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second;
@ -252,7 +252,7 @@ TEST_CASE("test fft grads") {
// Inverse
auto ifft_fn = [](array x) { return fft::ifft(x); };
vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second;
CHECK(array_equal(fft::ifft(cotangent), vjp_out).item<bool>());
CHECK(array_equal(fft::fft(cotangent) * 0.1, vjp_out).item<bool>());
jvp_out = jvp(ifft_fn, zeros_like(tangent), tangent).second;
CHECK(array_equal(fft::ifft(tangent), jvp_out).item<bool>());
@ -261,7 +261,8 @@ TEST_CASE("test fft grads") {
auto rfft_fn = [](array x) { return fft::rfft(x); };
cotangent = astype(arange(6), complex64);
vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second;
auto expected = astype(fft::fft(cotangent, 10, 0), float32);
array mask({1.0, 0.5, 0.5, 0.5, 0.5, 1.0}, complex64);
auto expected = fft::irfft(cotangent * mask, 10, 0) * 10;
CHECK(array_equal(expected, vjp_out).item<bool>());
tangent = astype(arange(10), float32);
@ -272,12 +273,9 @@ TEST_CASE("test fft grads") {
auto irfft_fn = [](array x) { return fft::irfft(x); };
cotangent = astype(arange(10), float32);
vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second;
expected = fft::fft(cotangent, 10, 0);
auto o_splits = split(vjp_out, {1, 5});
auto e_splits = split(expected, {1, 5, 6});
CHECK_EQ(e_splits[0].item<complex64_t>(), o_splits[0].item<complex64_t>());
CHECK(array_equal(2 * e_splits[1], o_splits[1]).item<bool>());
CHECK_EQ(e_splits[2].item<complex64_t>(), o_splits[2].item<complex64_t>());
mask = array({0.1, 0.2, 0.2, 0.2, 0.2, 0.1}, float32);
expected = fft::rfft(cotangent) * mask;
CHECK(array_equal(expected, vjp_out).item<bool>());
tangent = astype(arange(10), complex64);
jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second;