Limit grad recursion depth by not recursing through non-grad inputs (#1764)

* limit grad recursion depth

* add grad of module test
commit 33421c1dd3 (parent 5cc5201914)
Author: Awni Hannun (committed by GitHub)
Date:   2025-01-14 14:33:18 -08:00

6 changed files with 136 additions and 100 deletions
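
For context, `mx.value_and_grad` flattens its inputs by recursing through nested containers. Before this change the traversal recursed through every argument, including ones not selected by `argnums`, so a deeply nested non-grad input could exceed Python's recursion limit. A minimal sketch of the scenario (the nesting depth, toy function, and argument names are illustrative, not from the commit):

import mlx.core as mx

# A deeply nested container passed as a non-grad input.
deep = mx.zeros((1,))
for _ in range(5000):
    deep = [deep]

def fun(x, aux):
    # `aux` rides along untouched; only `x` is differentiated.
    return (x * x).sum()

# With argnums=0 the flattening should not need to recurse through
# `aux`, so the deep nesting stays harmless.
mx.grad(fun, argnums=0)(mx.array(2.0), deep)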


@@ -139,6 +139,8 @@ class TestAutograd(mlx_tests.MLXTestCase):
             mx.value_and_grad(fun, (None, None))
         with self.assertRaises(ValueError):
             mx.value_and_grad(fun, tuple())
+        with self.assertRaises(ValueError):
+            mx.grad(fun, argnums=(0, 0))
 
     def test_auxiliary_values(self):
         def fun(x, y):
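
The added assertion pins down new input validation: duplicate indices in `argnums` now raise a `ValueError` at transform time, before the gradient function is ever called. For example:

import mlx.core as mx

def fun(x, y):
    return x + y

try:
    # Duplicate indices are rejected when the transform is built.
    mx.grad(fun, argnums=(0, 0))
except ValueError:
    print("duplicate argnums rejected")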


@@ -195,6 +195,20 @@ class TestBase(mlx_tests.MLXTestCase):
         self.assertTrue(isinstance(m.layers[1], nn.ReLU))
         self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
 
+    def test_grad_of_module(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.m1 = nn.Linear(3, 3)
+
+        model = Model()
+
+        def loss_fn(model):
+            return model.m1(x).sum()
+
+        x = mx.zeros((3,))
+
+        mx.grad(loss_fn)(model)
 
 class TestLayers(mlx_tests.MLXTestCase):
     def test_identity(self):
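
The new test checks that `mx.grad` can be applied directly to a function of an `nn.Module`, taking gradients with respect to the module itself. The same pattern outside the test harness might look like the sketch below; the assumption that the returned gradients mirror the module's parameter tree is mine, not stated in the diff:

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(3, 3)
x = mx.zeros((3,))

def loss_fn(model):
    return model(x).sum()

# Gradients taken with respect to the module; assumed to come back
# as a tree matching the module's parameters (e.g. weight, bias).
grads = mx.grad(loss_fn)(model)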