mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	fix grad with inplace updates (#1961)
This commit is contained in:
		@@ -746,6 +746,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
 | 
			
		||||
            mx.checkpoint,
 | 
			
		||||
        ]:
 | 
			
		||||
            if mx.metal.is_available():
 | 
			
		||||
                mx.synchronize(mx.default_stream(mx.default_device()))
 | 
			
		||||
                mem_pre = mx.metal.get_active_memory()
 | 
			
		||||
            else:
 | 
			
		||||
                mem_pre = 0
 | 
			
		||||
@@ -790,6 +791,20 @@ class TestAutograd(mlx_tests.MLXTestCase):
 | 
			
		||||
        mx.grad(fun)(arrs)
 | 
			
		||||
        self.assertEqual(init_id, id(arrs[0]))
 | 
			
		||||
 | 
			
		||||
    def test_grad_with_inplace_update(self):
 | 
			
		||||
        def loss_fn(model):
 | 
			
		||||
            model[1] = mx.array(2.0)
 | 
			
		||||
            return model[0]
 | 
			
		||||
 | 
			
		||||
        model = [
 | 
			
		||||
            mx.array(0.0),
 | 
			
		||||
            mx.array(1.0),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        grad_fn = mx.grad(loss_fn)
 | 
			
		||||
        grad_fn(model)
 | 
			
		||||
        self.assertEqual(model[1].item(), 2.0)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user