mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-09 10:46:39 +08:00
fix grad with inplace updates (#1961)
This commit is contained in:
parent
d2a94f9e6a
commit
2770a10240
@ -176,7 +176,17 @@ auto py_value_and_grad(
|
|||||||
// Call the python function
|
// Call the python function
|
||||||
py_value_out = fun(*tree[0], **tree[1]);
|
py_value_out = fun(*tree[0], **tree[1]);
|
||||||
|
|
||||||
tree_fill(tree, arrays);
|
// Replace the tracers with the originals. Don't overwrite
|
||||||
|
// locations which were written to during the call to fun
|
||||||
|
int index = 0;
|
||||||
|
tree_visit_update(tree, [&](nb::handle node) {
|
||||||
|
auto replace_arr = nb::cast<mx::array>(node);
|
||||||
|
if (replace_arr.id() == a[index].id()) {
|
||||||
|
return nb::cast(arrays[index++]);
|
||||||
|
} else {
|
||||||
|
return nb::cast(replace_arr);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
// Validate the return value of the python function
|
// Validate the return value of the python function
|
||||||
if (!nb::isinstance<mx::array>(py_value_out)) {
|
if (!nb::isinstance<mx::array>(py_value_out)) {
|
||||||
|
@ -746,6 +746,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
mx.checkpoint,
|
mx.checkpoint,
|
||||||
]:
|
]:
|
||||||
if mx.metal.is_available():
|
if mx.metal.is_available():
|
||||||
|
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||||
mem_pre = mx.metal.get_active_memory()
|
mem_pre = mx.metal.get_active_memory()
|
||||||
else:
|
else:
|
||||||
mem_pre = 0
|
mem_pre = 0
|
||||||
@ -790,6 +791,20 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
mx.grad(fun)(arrs)
|
mx.grad(fun)(arrs)
|
||||||
self.assertEqual(init_id, id(arrs[0]))
|
self.assertEqual(init_id, id(arrs[0]))
|
||||||
|
|
||||||
|
def test_grad_with_inplace_update(self):
|
||||||
|
def loss_fn(model):
|
||||||
|
model[1] = mx.array(2.0)
|
||||||
|
return model[0]
|
||||||
|
|
||||||
|
model = [
|
||||||
|
mx.array(0.0),
|
||||||
|
mx.array(1.0),
|
||||||
|
]
|
||||||
|
|
||||||
|
grad_fn = mx.grad(loss_fn)
|
||||||
|
grad_fn(model)
|
||||||
|
self.assertEqual(model[1].item(), 2.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user