fix grad with inplace updates (#1961)

This commit is contained in:
Awni Hannun
2025-03-13 19:13:09 -07:00
committed by GitHub
parent d2a94f9e6a
commit 2770a10240
2 changed files with 26 additions and 1 deletions

View File

@@ -176,7 +176,17 @@ auto py_value_and_grad(
// Call the python function
py_value_out = fun(*tree[0], **tree[1]);
tree_fill(tree, arrays);
// Replace the tracers with the originals. Don't overwrite
// locations which were written to during the call to fun
int index = 0;
tree_visit_update(tree, [&](nb::handle node) {
auto replace_arr = nb::cast<mx::array>(node);
if (replace_arr.id() == a[index].id()) {
return nb::cast(arrays[index++]);
} else {
return nb::cast(replace_arr);
}
});
// Validate the return value of the python function
if (!nb::isinstance<mx::array>(py_value_out)) {