mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Limit grad recursion depth by not recursing through non-grad inputs (#1764)
* limit grad recursion depth
* add grad of module test
This commit is contained in:
		| @@ -139,6 +139,8 @@ class TestAutograd(mlx_tests.MLXTestCase): | ||||
|             mx.value_and_grad(fun, (None, None)) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.value_and_grad(fun, tuple()) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.grad(fun, argnums=(0, 0)) | ||||
|  | ||||
|     def test_auxiliary_values(self): | ||||
|         def fun(x, y): | ||||
|   | ||||
| @@ -195,6 +195,20 @@ class TestBase(mlx_tests.MLXTestCase): | ||||
|         self.assertTrue(isinstance(m.layers[1], nn.ReLU)) | ||||
|         self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear)) | ||||
|  | ||||
|     def test_grad_of_module(self): | ||||
|         class Model(nn.Module): | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 self.m1 = nn.Linear(3, 3) | ||||
|  | ||||
|         model = Model() | ||||
|  | ||||
|         def loss_fn(model): | ||||
|             return model.m1(x).sum() | ||||
|  | ||||
|         x = mx.zeros((3,)) | ||||
|         mx.grad(loss_fn)(model) | ||||
|  | ||||
|  | ||||
| class TestLayers(mlx_tests.MLXTestCase): | ||||
|     def test_identity(self): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun