mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 21:37:50 +08:00
@@ -393,6 +393,28 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
g = mx.grad(lambda x: x**2)(x)
|
||||
self.assertAlmostEqual(g.item(), 4.0)
|
||||
|
||||
def test_eval_in_grad(self):
|
||||
arr = mx.array([1.0])
|
||||
cotan = mx.array([1.0, 1.0])
|
||||
y = mx.array([2.0, 2.0])
|
||||
|
||||
def func(x):
|
||||
x = x + y
|
||||
cond = x < 1
|
||||
cond.tolist()
|
||||
return x**2
|
||||
|
||||
_, vjps = mx.vjp(func, (arr,), (cotan,))
|
||||
self.assertEqual(vjps[0].item(), 12.0)
|
||||
|
||||
def func(x):
|
||||
x = x + mx.array([1.0, 1.0])
|
||||
mx.eval(x)
|
||||
return x**2
|
||||
|
||||
_, vjps = mx.vjp(func, (arr,), (cotan,))
|
||||
self.assertEqual(vjps[0].item(), 8.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user