Fix grad copies (#1854)

* fix grad with copies

* add test

* add test
This commit is contained in:
Awni Hannun 2025-02-11 15:26:42 -08:00 committed by GitHub
parent 2a45056ba8
commit 0a5215693e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 2 deletions

View File

@ -171,12 +171,12 @@ auto py_value_and_grad(
nb::list tree;
tree.append(args);
tree.append(kwargs);
tree_replace(tree, arrays, a);
tree_fill(tree, a);
// Call the python function
py_value_out = fun(*tree[0], **tree[1]);
tree_replace(tree, arrays, a);
tree_fill(tree, arrays);
// Validate the return value of the python function
if (!nb::isinstance<mx::array>(py_value_out)) {

View File

@ -770,6 +770,26 @@ class TestAutograd(mlx_tests.MLXTestCase):
self.assertEqual(mem_pre, mem_post)
def test_grad_with_copies(self):
a = mx.array(2.0)
arrays = [a, a, a]
def fun(arrays):
return arrays[0] + arrays[2]
grads = mx.grad(fun)(arrays)
self.assertEqual(grads[0].item(), 1.0)
self.assertEqual(grads[2].item(), 1.0)
def test_grad_ids_pre_post(self):
def fun(arrs):
return arrs[0]
arrs = [mx.array(1.0)]
init_id = id(arrs[0])
mx.grad(fun)(arrs)
self.assertEqual(init_id, id(arrs[0]))
if __name__ == "__main__":
unittest.main()