diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 8585bd378..8d78a1bde 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -171,12 +171,12 @@ auto py_value_and_grad( nb::list tree; tree.append(args); tree.append(kwargs); - tree_replace(tree, arrays, a); + tree_fill(tree, a); // Call the python function py_value_out = fun(*tree[0], **tree[1]); - tree_replace(tree, arrays, a); + tree_fill(tree, arrays); // Validate the return value of the python function if (!nb::isinstance(py_value_out)) { diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 3ec020270..ffccd85fc 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -770,6 +770,26 @@ class TestAutograd(mlx_tests.MLXTestCase): self.assertEqual(mem_pre, mem_post) + def test_grad_with_copies(self): + a = mx.array(2.0) + arrays = [a, a, a] + + def fun(arrays): + return arrays[0] + arrays[2] + + grads = mx.grad(fun)(arrays) + self.assertEqual(grads[0].item(), 1.0) + self.assertEqual(grads[2].item(), 1.0) + + def test_grad_ids_pre_post(self): + def fun(arrs): + return arrs[0] + + arrs = [mx.array(1.0)] + init_id = id(arrs[0]) + mx.grad(fun)(arrs) + self.assertEqual(init_id, id(arrs[0])) + if __name__ == "__main__": unittest.main()