mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Removes the retain_graph flag (#385)
				
					
				
			* Adds global tracing flag * Removes retain_graph in favor of is_tracer
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							449b43762e
						
					
				
				
					commit
					a611b0bc82
				
			| @@ -259,6 +259,21 @@ class TestAutograd(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in))) | ||||
|  | ||||
|     def test_update_state(self): | ||||
|         y = mx.array([1.0]) | ||||
|         state = mx.zeros((2,)) | ||||
|  | ||||
|         def fn(y, x): | ||||
|             nonlocal state | ||||
|             x = y * x | ||||
|             state = state + x | ||||
|             return x.sum() | ||||
|  | ||||
|         x = mx.ones((2,)) | ||||
|         mx.grad(fn)(y, x) | ||||
|         mx.eval(state) | ||||
|         self.assertTrue(mx.allclose(state, mx.ones((2,)))) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -15,18 +15,13 @@ class TestEval(mlx_tests.MLXTestCase): | ||||
|             self.assertEqual(x.tolist(), [[1, 1], [1, 1]]) | ||||
|  | ||||
|     def test_retain_graph(self): | ||||
|         def fun(x, retain_graph): | ||||
|         def fun(x): | ||||
|             y = 3 * x | ||||
|             mx.eval(y, retain_graph=retain_graph) | ||||
|             mx.eval(y) | ||||
|             return 2 * y | ||||
|  | ||||
|         dfun_dx_1 = mx.grad(partial(fun, retain_graph=False)) | ||||
|         dfun_dx_2 = mx.grad(partial(fun, retain_graph=True)) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             dfun_dx_1(mx.array(1.0)) | ||||
|  | ||||
|         y = dfun_dx_2(mx.array(1.0)) | ||||
|         dfun_dx = mx.grad(fun) | ||||
|         y = dfun_dx(mx.array(1.0)) | ||||
|         self.assertEqual(y.item(), 6.0) | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user