Removes the retain_graph flag (#385)

* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
Angelos Katharopoulos
2024-01-07 15:16:51 -08:00
committed by GitHub
parent 449b43762e
commit a611b0bc82
22 changed files with 209 additions and 207 deletions

View File

@@ -259,6 +259,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in)))
def test_update_state(self):
y = mx.array([1.0])
state = mx.zeros((2,))
def fn(y, x):
nonlocal state
x = y * x
state = state + x
return x.sum()
x = mx.ones((2,))
mx.grad(fn)(y, x)
mx.eval(state)
self.assertTrue(mx.allclose(state, mx.ones((2,))))
if __name__ == "__main__":
unittest.main()