diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp
index 0530fa089..67863038e 100644
--- a/python/src/transforms.cpp
+++ b/python/src/transforms.cpp
@@ -124,37 +124,53 @@ auto py_value_and_grad(
 
     // Collect the arrays
     std::vector<mx::array> arrays;
+    std::vector<nb::object> array_objects;
+    auto flatten_with_objects = [&arrays, &array_objects](
+                                    auto tree, bool strict) {
+      tree_visit(tree, [&](nb::handle obj) {
+        if (nb::isinstance<mx::array>(obj)) {
+          arrays.push_back(nb::cast<mx::array>(obj));
+          array_objects.push_back(nb::borrow(obj));
+        } else if (strict) {
+          throw std::invalid_argument(
+              "[tree_flatten] The argument should contain only arrays");
+        }
+      });
+    };
+
     std::vector<int> counts(1, 0);
     std::vector<int> gradient_indices;
     for (int i = 0, j = 0; i < args.size(); ++i) {
       bool needs_grad = (j < argnums.size() && argnums[j] == i);
-      auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
+      auto pre_size = arrays.size();
+      flatten_with_objects(args[i], /* strict = */ needs_grad);
       if (needs_grad) {
         auto old_size = gradient_indices.size();
-        gradient_indices.resize(old_size + argsi.size());
+        auto delta_size = arrays.size() - pre_size;
+        gradient_indices.resize(old_size + delta_size);
         std::iota(
             gradient_indices.begin() + old_size,
             gradient_indices.end(),
-            arrays.size());
+            pre_size);
         j++;
-        counts.push_back(argsi.size());
+        counts.push_back(delta_size);
       }
-      arrays.insert(arrays.end(), argsi.begin(), argsi.end());
     }
     for (auto item : kwargs) {
       bool needs_grad =
           (argnames.find(nb::cast<std::string>(item.first)) != argnames.end());
-      auto argsk = tree_flatten(item.second, /* strict = */ needs_grad);
+      auto pre_size = arrays.size();
+      flatten_with_objects(item.second, /* strict = */ needs_grad);
       if (needs_grad) {
         auto old_size = gradient_indices.size();
-        gradient_indices.resize(old_size + argsk.size());
+        auto delta_size = arrays.size() - pre_size;
+        gradient_indices.resize(old_size + delta_size);
         std::iota(
             gradient_indices.begin() + old_size,
             gradient_indices.end(),
-            arrays.size());
-        counts.push_back(argsk.size());
+            pre_size);
+        counts.push_back(delta_size);
       }
-      arrays.insert(arrays.end(), argsk.begin(), argsk.end());
     }
     std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
 
@@ -163,7 +179,7 @@ auto py_value_and_grad(
     nb::object py_value_out;
     auto value_and_grads = mx::value_and_grad(
         [&fun,
-         &arrays,
+         &array_objects,
          &args,
          &kwargs,
          &py_value_out,
@@ -183,8 +199,9 @@
           tree_visit_update(tree, [&](nb::handle node) {
             auto replace_arr = nb::cast<mx::array>(node);
             if (replace_arr.id() == a[index].id()) {
-              return nb::cast(arrays[index++]);
+              return array_objects[index++];
             } else {
+              index++;
               return nb::cast(replace_arr);
             }
           });
diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py
index 218ea3ce1..c37161a4d 100644
--- a/python/tests/test_autograd.py
+++ b/python/tests/test_autograd.py
@@ -780,9 +780,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
             return arrs[0]
 
         arrs = [mx.array(1.0)]
-        init_id = id(arrs[0])
+        arr = arrs[0]
         mx.grad(fun)(arrs)
-        self.assertEqual(init_id, id(arrs[0]))
+        self.assertEqual(id(arr), id(arrs[0]))
+
+        def fun(arrs):
+            arrs[1] = sum(arrs)
+            return arrs[1]
+
+        arrs = [mx.array(1.0), mx.array(1.0), mx.array(1.0)]
+        a_0, a_1, a_2 = arrs
+
+        mx.grad(fun)(arrs)
+        self.assertEqual(id(a_0), id(arrs[0]))
+        self.assertNotEqual(id(a_1), id(arrs[1]))
+        self.assertEqual(id(a_2), id(arrs[2]))
 
     def test_grad_with_inplace_update(self):
         def loss_fn(model):
diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py
index 8d9ee7051..7847a9a60 100644
--- a/python/tests/test_vmap.py
+++ b/python/tests/test_vmap.py
@@ -744,7 +744,6 @@ class TestVmap(mlx_tests.MLXTestCase):
             return Vector([t[0] + 10, t[1] * 10])
 
         x = State(mx.array(1), mx.array(2))
-        print(f"{transform(x)=}")
 
         vmap_transform = mx.vmap(transform)
         vmap_transform_tuple = mx.vmap(transform_tuple)
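The user-visible effect of the transforms.cpp change: after `mx.grad` (or `mx.value_and_grad`) returns, container entries that the wrapped function never reassigned keep their original Python object identity, while an entry the function overwrote in place refers to a new object. Below is a minimal standalone sketch of that behavior, restating the assertions added in test_autograd.py; it assumes only `mlx.core` (`mx.array`, `mx.grad`) as exercised by those tests.

```python
# Standalone restatement of the identity checks added in test_autograd.py.
import mlx.core as mx


def fun(arrs):
    # Overwrite one list element in place inside the grad'd function.
    arrs[1] = sum(arrs)
    return arrs[1]


arrs = [mx.array(1.0), mx.array(1.0), mx.array(1.0)]
a_0, a_1, a_2 = arrs

mx.grad(fun)(arrs)

# Entries the function never reassigned keep their original Python objects.
assert id(a_0) == id(arrs[0])
assert id(a_2) == id(arrs[2])
# The entry overwritten inside fun now points at a new object.
assert id(a_1) != id(arrs[1])
```

On the C++ side this works because the flattening step now records a borrowed `nb::object` for every array it visits, so the original Python objects (rather than freshly cast wrappers) are put back into the caller's tree, and the replacement index is advanced even when an element no longer matches the traced array.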