Fixx rfft odd grad and add tests

This commit is contained in:
Angelos Katharopoulos
2025-05-12 22:27:12 -07:00
parent e1c65e1381
commit 194f1adbd8
3 changed files with 90 additions and 32 deletions

View File

@@ -1149,7 +1149,7 @@ TEST_CASE("test complex gradients") {
auto cotan = array(complex64_t{2.0, 3.0});
out = vjp([x](array a) { return multiply(a, x); }, y, cotan).second;
CHECK_EQ(out.dtype(), float32);
CHECK_EQ(out.item<float>(), -8.0);
CHECK_EQ(out.item<float>(), 16.0);
out = vjp([y](array a) { return multiply(a, y); }, x, cotan).second;
CHECK_EQ(out.item<complex64_t>(), complex64_t{6.0, 9.0});