From 2770a1024082eb10cce6bc0ac589ad089e7be611 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 13 Mar 2025 19:13:09 -0700 Subject: [PATCH] fix grad with inplace updates (#1961) --- python/src/transforms.cpp | 12 +++++++++++- python/tests/test_autograd.py | 15 +++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 8d78a1bdec..4a5e2e6ac8 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -176,7 +176,17 @@ auto py_value_and_grad( // Call the python function py_value_out = fun(*tree[0], **tree[1]); - tree_fill(tree, arrays); + // Replace the tracers with the originals. Don't overwrite + // locations which were written to during the call to fun + int index = 0; + tree_visit_update(tree, [&](nb::handle node) { + auto replace_arr = nb::cast(node); + if (replace_arr.id() == a[index].id()) { + return nb::cast(arrays[index++]); + } else { + return nb::cast(replace_arr); + } + }); // Validate the return value of the python function if (!nb::isinstance(py_value_out)) { diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index ffccd85fc2..350b098370 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -746,6 +746,7 @@ class TestAutograd(mlx_tests.MLXTestCase): mx.checkpoint, ]: if mx.metal.is_available(): + mx.synchronize(mx.default_stream(mx.default_device())) mem_pre = mx.metal.get_active_memory() else: mem_pre = 0 @@ -790,6 +791,20 @@ class TestAutograd(mlx_tests.MLXTestCase): mx.grad(fun)(arrs) self.assertEqual(init_id, id(arrs[0])) + def test_grad_with_inplace_update(self): + def loss_fn(model): + model[1] = mx.array(2.0) + return model[0] + + model = [ + mx.array(0.0), + mx.array(1.0), + ] + + grad_fn = mx.grad(loss_fn) + grad_fn(model) + self.assertEqual(model[1].item(), 2.0) + if __name__ == "__main__": unittest.main()