mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-09 10:46:39 +08:00
fix grad with inplace updates (#1961)
This commit is contained in:
parent
d2a94f9e6a
commit
2770a10240
@ -176,7 +176,17 @@ auto py_value_and_grad(
|
||||
// Call the python function
|
||||
py_value_out = fun(*tree[0], **tree[1]);
|
||||
|
||||
tree_fill(tree, arrays);
|
||||
// Replace the tracers with the originals. Don't overwrite
|
||||
// locations which were written to during the call to fun
|
||||
int index = 0;
|
||||
tree_visit_update(tree, [&](nb::handle node) {
|
||||
auto replace_arr = nb::cast<mx::array>(node);
|
||||
if (replace_arr.id() == a[index].id()) {
|
||||
return nb::cast(arrays[index++]);
|
||||
} else {
|
||||
return nb::cast(replace_arr);
|
||||
}
|
||||
});
|
||||
|
||||
// Validate the return value of the python function
|
||||
if (!nb::isinstance<mx::array>(py_value_out)) {
|
||||
|
@ -746,6 +746,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
mx.checkpoint,
|
||||
]:
|
||||
if mx.metal.is_available():
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
@ -790,6 +791,20 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
mx.grad(fun)(arrs)
|
||||
self.assertEqual(init_id, id(arrs[0]))
|
||||
|
||||
def test_grad_with_inplace_update(self):
|
||||
def loss_fn(model):
|
||||
model[1] = mx.array(2.0)
|
||||
return model[0]
|
||||
|
||||
model = [
|
||||
mx.array(0.0),
|
||||
mx.array(1.0),
|
||||
]
|
||||
|
||||
grad_fn = mx.grad(loss_fn)
|
||||
grad_fn(model)
|
||||
self.assertEqual(model[1].item(), 2.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user