mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 01:19:21 +08:00
fix grad with inplace updates (#1961)
This commit is contained in:
@@ -176,7 +176,17 @@ auto py_value_and_grad(
|
||||
// Call the python function
|
||||
py_value_out = fun(*tree[0], **tree[1]);
|
||||
|
||||
tree_fill(tree, arrays);
|
||||
// Replace the tracers with the originals. Don't overwrite
|
||||
// locations which were written to during the call to fun
|
||||
int index = 0;
|
||||
tree_visit_update(tree, [&](nb::handle node) {
|
||||
auto replace_arr = nb::cast<mx::array>(node);
|
||||
if (replace_arr.id() == a[index].id()) {
|
||||
return nb::cast(arrays[index++]);
|
||||
} else {
|
||||
return nb::cast(replace_arr);
|
||||
}
|
||||
});
|
||||
|
||||
// Validate the return value of the python function
|
||||
if (!nb::isinstance<mx::array>(py_value_out)) {
|
||||
|
||||
Reference in New Issue
Block a user