fix grad with inplace updates (#1961)

This commit is contained in:
Awni Hannun
2025-03-13 19:13:09 -07:00
committed by GitHub
parent d2a94f9e6a
commit 2770a10240
2 changed files with 26 additions and 1 deletions

View File

@@ -746,6 +746,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
mx.checkpoint,
]:
if mx.metal.is_available():
mx.synchronize(mx.default_stream(mx.default_device()))
mem_pre = mx.metal.get_active_memory()
else:
mem_pre = 0
@@ -790,6 +791,20 @@ class TestAutograd(mlx_tests.MLXTestCase):
mx.grad(fun)(arrs)
self.assertEqual(init_id, id(arrs[0]))
def test_grad_with_inplace_update(self):
def loss_fn(model):
model[1] = mx.array(2.0)
return model[0]
model = [
mx.array(0.0),
mx.array(1.0),
]
grad_fn = mx.grad(loss_fn)
grad_fn(model)
self.assertEqual(model[1].item(), 2.0)
if __name__ == "__main__":
unittest.main()