mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-11 19:56:40 +08:00
Fix the fft tests in C++
This commit is contained in:
parent
194f1adbd8
commit
f93cda7a1c
@ -243,7 +243,7 @@ TEST_CASE("test fft grads") {
|
||||
auto fft_fn = [](array x) { return fft::fft(x); };
|
||||
auto cotangent = astype(arange(10), complex64);
|
||||
auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second;
|
||||
CHECK(array_equal(fft::fft(cotangent), vjp_out).item<bool>());
|
||||
CHECK(array_equal(fft::ifft(cotangent) * 10, vjp_out).item<bool>());
|
||||
|
||||
auto tangent = astype(arange(10), complex64);
|
||||
auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second;
|
||||
@ -252,7 +252,7 @@ TEST_CASE("test fft grads") {
|
||||
// Inverse
|
||||
auto ifft_fn = [](array x) { return fft::ifft(x); };
|
||||
vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second;
|
||||
CHECK(array_equal(fft::ifft(cotangent), vjp_out).item<bool>());
|
||||
CHECK(array_equal(fft::fft(cotangent) * 0.1, vjp_out).item<bool>());
|
||||
|
||||
jvp_out = jvp(ifft_fn, zeros_like(tangent), tangent).second;
|
||||
CHECK(array_equal(fft::ifft(tangent), jvp_out).item<bool>());
|
||||
@ -261,7 +261,8 @@ TEST_CASE("test fft grads") {
|
||||
auto rfft_fn = [](array x) { return fft::rfft(x); };
|
||||
cotangent = astype(arange(6), complex64);
|
||||
vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second;
|
||||
auto expected = astype(fft::fft(cotangent, 10, 0), float32);
|
||||
array mask({1.0, 0.5, 0.5, 0.5, 0.5, 1.0}, complex64);
|
||||
auto expected = fft::irfft(cotangent * mask, 10, 0) * 10;
|
||||
CHECK(array_equal(expected, vjp_out).item<bool>());
|
||||
|
||||
tangent = astype(arange(10), float32);
|
||||
@ -272,12 +273,9 @@ TEST_CASE("test fft grads") {
|
||||
auto irfft_fn = [](array x) { return fft::irfft(x); };
|
||||
cotangent = astype(arange(10), float32);
|
||||
vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second;
|
||||
expected = fft::fft(cotangent, 10, 0);
|
||||
auto o_splits = split(vjp_out, {1, 5});
|
||||
auto e_splits = split(expected, {1, 5, 6});
|
||||
CHECK_EQ(e_splits[0].item<complex64_t>(), o_splits[0].item<complex64_t>());
|
||||
CHECK(array_equal(2 * e_splits[1], o_splits[1]).item<bool>());
|
||||
CHECK_EQ(e_splits[2].item<complex64_t>(), o_splits[2].item<complex64_t>());
|
||||
mask = array({0.1, 0.2, 0.2, 0.2, 0.2, 0.1}, float32);
|
||||
expected = fft::rfft(cotangent) * mask;
|
||||
CHECK(array_equal(expected, vjp_out).item<bool>());
|
||||
|
||||
tangent = astype(arange(10), complex64);
|
||||
jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second;
|
||||
|
Loading…
Reference in New Issue
Block a user