mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 21:37:50 +08:00
Limit grad recursion depth by not recursing through non-grad inputs (#1764)
* limit grad recursion depth * add grad of module test
This commit is contained in:
@@ -139,6 +139,8 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
mx.value_and_grad(fun, (None, None))
|
||||
with self.assertRaises(ValueError):
|
||||
mx.value_and_grad(fun, tuple())
|
||||
with self.assertRaises(ValueError):
|
||||
mx.grad(fun, argnums=(0, 0))
|
||||
|
||||
def test_auxiliary_values(self):
|
||||
def fun(x, y):
|
||||
|
@@ -195,6 +195,20 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(isinstance(m.layers[1], nn.ReLU))
|
||||
self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
|
||||
|
||||
def test_grad_of_module(self):
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.m1 = nn.Linear(3, 3)
|
||||
|
||||
model = Model()
|
||||
|
||||
def loss_fn(model):
|
||||
return model.m1(x).sum()
|
||||
|
||||
x = mx.zeros((3,))
|
||||
mx.grad(loss_fn)(model)
|
||||
|
||||
|
||||
class TestLayers(mlx_tests.MLXTestCase):
|
||||
def test_identity(self):
|
||||
|
Reference in New Issue
Block a user