Fix eval in trace bugs (#612)

* Fix eval in trace bugs

* comment nit
This commit is contained in:
Awni Hannun
2024-02-02 09:57:12 -08:00
committed by GitHub
parent 506d43035c
commit cb6156d35d
4 changed files with 33 additions and 12 deletions

View File

@@ -393,6 +393,28 @@ class TestAutograd(mlx_tests.MLXTestCase):
g = mx.grad(lambda x: x**2)(x)
self.assertAlmostEqual(g.item(), 4.0)
def test_eval_in_grad(self):
arr = mx.array([1.0])
cotan = mx.array([1.0, 1.0])
y = mx.array([2.0, 2.0])
def func(x):
x = x + y
cond = x < 1
cond.tolist()
return x**2
_, vjps = mx.vjp(func, (arr,), (cotan,))
self.assertEqual(vjps[0].item(), 12.0)
def func(x):
x = x + mx.array([1.0, 1.0])
mx.eval(x)
return x**2
_, vjps = mx.vjp(func, (arr,), (cotan,))
self.assertEqual(vjps[0].item(), 8.0)
if __name__ == "__main__":
unittest.main()