Fix imag() vjp (#2367)

This commit is contained in:
Angelos Katharopoulos
2025-07-14 13:11:16 -07:00
committed by GitHub
parent 2d3c26c565
commit 5201df5030
2 changed files with 10 additions and 6 deletions

View File

@@ -606,7 +606,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])
dfdx = mx.grad(fun)(x)
self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x)))
self.assertTrue(mx.allclose(dfdx, 2j * mx.ones_like(x)))
def test_flatten_unflatten_vjps(self):
def fun(x):