Fix grad copies (#1854)

* fix grad with copies

* add test

* add test
This commit is contained in:
Awni Hannun
2025-02-11 15:26:42 -08:00
committed by GitHub
parent 2a45056ba8
commit 0a5215693e
2 changed files with 22 additions and 2 deletions

View File

@@ -171,12 +171,12 @@ auto py_value_and_grad(
nb::list tree;
tree.append(args);
tree.append(kwargs);
tree_replace(tree, arrays, a);
tree_fill(tree, a);
// Call the python function
py_value_out = fun(*tree[0], **tree[1]);
tree_replace(tree, arrays, a);
tree_fill(tree, arrays);
// Validate the return value of the python function
if (!nb::isinstance<mx::array>(py_value_out)) {