mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 07:18:29 +08:00
Fix imag() vjp (#2367)
This commit is contained in:

committed by
GitHub

parent
2d3c26c565
commit
5201df5030
@@ -606,7 +606,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
|
||||
x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])
|
||||
dfdx = mx.grad(fun)(x)
|
||||
self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x)))
|
||||
self.assertTrue(mx.allclose(dfdx, 2j * mx.ones_like(x)))
|
||||
|
||||
def test_flatten_unflatten_vjps(self):
|
||||
def fun(x):
|
||||
|
Reference in New Issue
Block a user