diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp index 0db3999c8d..b9e2d1bcc0 100644 --- a/tests/fft_tests.cpp +++ b/tests/fft_tests.cpp @@ -243,7 +243,7 @@ TEST_CASE("test fft grads") { auto fft_fn = [](array x) { return fft::fft(x); }; auto cotangent = astype(arange(10), complex64); auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second; - CHECK(array_equal(fft::fft(cotangent), vjp_out).item()); + CHECK(array_equal(fft::ifft(cotangent) * 10, vjp_out).item()); auto tangent = astype(arange(10), complex64); auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second; @@ -252,7 +252,7 @@ TEST_CASE("test fft grads") { // Inverse auto ifft_fn = [](array x) { return fft::ifft(x); }; vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second; - CHECK(array_equal(fft::ifft(cotangent), vjp_out).item()); + CHECK(array_equal(fft::fft(cotangent) * 0.1, vjp_out).item()); jvp_out = jvp(ifft_fn, zeros_like(tangent), tangent).second; CHECK(array_equal(fft::ifft(tangent), jvp_out).item()); @@ -261,7 +261,8 @@ TEST_CASE("test fft grads") { auto rfft_fn = [](array x) { return fft::rfft(x); }; cotangent = astype(arange(6), complex64); vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second; - auto expected = astype(fft::fft(cotangent, 10, 0), float32); + array mask({1.0, 0.5, 0.5, 0.5, 0.5, 1.0}, complex64); + auto expected = fft::irfft(cotangent * mask, 10, 0) * 10; CHECK(array_equal(expected, vjp_out).item()); tangent = astype(arange(10), float32); @@ -272,12 +273,9 @@ TEST_CASE("test fft grads") { auto irfft_fn = [](array x) { return fft::irfft(x); }; cotangent = astype(arange(10), float32); vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second; - expected = fft::fft(cotangent, 10, 0); - auto o_splits = split(vjp_out, {1, 5}); - auto e_splits = split(expected, {1, 5, 6}); - CHECK_EQ(e_splits[0].item(), o_splits[0].item()); - CHECK(array_equal(2 * e_splits[1], o_splits[1]).item()); - CHECK_EQ(e_splits[2].item(), o_splits[2].item()); + mask = array({0.1, 0.2, 0.2, 0.2, 0.2, 0.1}, float32); + expected = fft::rfft(cotangent) * mask; + CHECK(array_equal(expected, vjp_out).item()); tangent = astype(arange(10), complex64); jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second;