mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 13:07:29 +08:00
@@ -770,6 +770,26 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertEqual(mem_pre, mem_post)
|
||||
|
||||
def test_grad_with_copies(self):
|
||||
a = mx.array(2.0)
|
||||
arrays = [a, a, a]
|
||||
|
||||
def fun(arrays):
|
||||
return arrays[0] + arrays[2]
|
||||
|
||||
grads = mx.grad(fun)(arrays)
|
||||
self.assertEqual(grads[0].item(), 1.0)
|
||||
self.assertEqual(grads[2].item(), 1.0)
|
||||
|
||||
def test_grad_ids_pre_post(self):
|
||||
def fun(arrs):
|
||||
return arrs[0]
|
||||
|
||||
arrs = [mx.array(1.0)]
|
||||
init_id = id(arrs[0])
|
||||
mx.grad(fun)(arrs)
|
||||
self.assertEqual(init_id, id(arrs[0]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user