mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fixx rfft odd grad and add tests
This commit is contained in:
@@ -1149,7 +1149,7 @@ TEST_CASE("test complex gradients") {
|
||||
auto cotan = array(complex64_t{2.0, 3.0});
|
||||
out = vjp([x](array a) { return multiply(a, x); }, y, cotan).second;
|
||||
CHECK_EQ(out.dtype(), float32);
|
||||
CHECK_EQ(out.item<float>(), -8.0);
|
||||
CHECK_EQ(out.item<float>(), 16.0);
|
||||
|
||||
out = vjp([y](array a) { return multiply(a, y); }, x, cotan).second;
|
||||
CHECK_EQ(out.item<complex64_t>(), complex64_t{6.0, 9.0});
|
||||
|
||||
Reference in New Issue
Block a user