Limit grad recursion depth by not recursing through non-grad inputs (#1764)

* limit grad recursion depth

* add grad of module test
This commit is contained in:
Awni Hannun
2025-01-14 14:33:18 -08:00
committed by GitHub
parent 5cc5201914
commit 33421c1dd3
6 changed files with 136 additions and 100 deletions

View File

@@ -139,6 +139,8 @@ class TestAutograd(mlx_tests.MLXTestCase):
mx.value_and_grad(fun, (None, None))
with self.assertRaises(ValueError):
mx.value_and_grad(fun, tuple())
with self.assertRaises(ValueError):
mx.grad(fun, argnums=(0, 0))
def test_auxiliary_values(self):
def fun(x, y):