mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
parent
2a45056ba8
commit
0a5215693e
@ -171,12 +171,12 @@ auto py_value_and_grad(
|
|||||||
nb::list tree;
|
nb::list tree;
|
||||||
tree.append(args);
|
tree.append(args);
|
||||||
tree.append(kwargs);
|
tree.append(kwargs);
|
||||||
tree_replace(tree, arrays, a);
|
tree_fill(tree, a);
|
||||||
|
|
||||||
// Call the python function
|
// Call the python function
|
||||||
py_value_out = fun(*tree[0], **tree[1]);
|
py_value_out = fun(*tree[0], **tree[1]);
|
||||||
|
|
||||||
tree_replace(tree, arrays, a);
|
tree_fill(tree, arrays);
|
||||||
|
|
||||||
// Validate the return value of the python function
|
// Validate the return value of the python function
|
||||||
if (!nb::isinstance<mx::array>(py_value_out)) {
|
if (!nb::isinstance<mx::array>(py_value_out)) {
|
||||||
|
@ -770,6 +770,26 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertEqual(mem_pre, mem_post)
|
self.assertEqual(mem_pre, mem_post)
|
||||||
|
|
||||||
|
def test_grad_with_copies(self):
|
||||||
|
a = mx.array(2.0)
|
||||||
|
arrays = [a, a, a]
|
||||||
|
|
||||||
|
def fun(arrays):
|
||||||
|
return arrays[0] + arrays[2]
|
||||||
|
|
||||||
|
grads = mx.grad(fun)(arrays)
|
||||||
|
self.assertEqual(grads[0].item(), 1.0)
|
||||||
|
self.assertEqual(grads[2].item(), 1.0)
|
||||||
|
|
||||||
|
def test_grad_ids_pre_post(self):
|
||||||
|
def fun(arrs):
|
||||||
|
return arrs[0]
|
||||||
|
|
||||||
|
arrs = [mx.array(1.0)]
|
||||||
|
init_id = id(arrs[0])
|
||||||
|
mx.grad(fun)(arrs)
|
||||||
|
self.assertEqual(init_id, id(arrs[0]))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user