Fix the fft tests in C++

This commit is contained in:
Angelos Katharopoulos 2025-05-12 22:52:35 -07:00
parent 194f1adbd8
commit f93cda7a1c

View File

@ -243,7 +243,7 @@ TEST_CASE("test fft grads") {
auto fft_fn = [](array x) { return fft::fft(x); }; auto fft_fn = [](array x) { return fft::fft(x); };
auto cotangent = astype(arange(10), complex64); auto cotangent = astype(arange(10), complex64);
auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second; auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second;
CHECK(array_equal(fft::fft(cotangent), vjp_out).item<bool>()); CHECK(array_equal(fft::ifft(cotangent) * 10, vjp_out).item<bool>());
auto tangent = astype(arange(10), complex64); auto tangent = astype(arange(10), complex64);
auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second; auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second;
@ -252,7 +252,7 @@ TEST_CASE("test fft grads") {
// Inverse // Inverse
auto ifft_fn = [](array x) { return fft::ifft(x); }; auto ifft_fn = [](array x) { return fft::ifft(x); };
vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second; vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second;
CHECK(array_equal(fft::ifft(cotangent), vjp_out).item<bool>()); CHECK(array_equal(fft::fft(cotangent) * 0.1, vjp_out).item<bool>());
jvp_out = jvp(ifft_fn, zeros_like(tangent), tangent).second; jvp_out = jvp(ifft_fn, zeros_like(tangent), tangent).second;
CHECK(array_equal(fft::ifft(tangent), jvp_out).item<bool>()); CHECK(array_equal(fft::ifft(tangent), jvp_out).item<bool>());
@ -261,7 +261,8 @@ TEST_CASE("test fft grads") {
auto rfft_fn = [](array x) { return fft::rfft(x); }; auto rfft_fn = [](array x) { return fft::rfft(x); };
cotangent = astype(arange(6), complex64); cotangent = astype(arange(6), complex64);
vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second; vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second;
auto expected = astype(fft::fft(cotangent, 10, 0), float32); array mask({1.0, 0.5, 0.5, 0.5, 0.5, 1.0}, complex64);
auto expected = fft::irfft(cotangent * mask, 10, 0) * 10;
CHECK(array_equal(expected, vjp_out).item<bool>()); CHECK(array_equal(expected, vjp_out).item<bool>());
tangent = astype(arange(10), float32); tangent = astype(arange(10), float32);
@ -272,12 +273,9 @@ TEST_CASE("test fft grads") {
auto irfft_fn = [](array x) { return fft::irfft(x); }; auto irfft_fn = [](array x) { return fft::irfft(x); };
cotangent = astype(arange(10), float32); cotangent = astype(arange(10), float32);
vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second; vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second;
expected = fft::fft(cotangent, 10, 0); mask = array({0.1, 0.2, 0.2, 0.2, 0.2, 0.1}, float32);
auto o_splits = split(vjp_out, {1, 5}); expected = fft::rfft(cotangent) * mask;
auto e_splits = split(expected, {1, 5, 6}); CHECK(array_equal(expected, vjp_out).item<bool>());
CHECK_EQ(e_splits[0].item<complex64_t>(), o_splits[0].item<complex64_t>());
CHECK(array_equal(2 * e_splits[1], o_splits[1]).item<bool>());
CHECK_EQ(e_splits[2].item<complex64_t>(), o_splits[2].item<complex64_t>());
tangent = astype(arange(10), complex64); tangent = astype(arange(10), complex64);
jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second; jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second;