Removes the retain_graph flag (#385)

* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
Angelos Katharopoulos
2024-01-07 15:16:51 -08:00
committed by GitHub
parent 449b43762e
commit a611b0bc82
22 changed files with 209 additions and 207 deletions

View File

@@ -259,6 +259,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in)))
def test_update_state(self):
y = mx.array([1.0])
state = mx.zeros((2,))
def fn(y, x):
nonlocal state
x = y * x
state = state + x
return x.sum()
x = mx.ones((2,))
mx.grad(fn)(y, x)
mx.eval(state)
self.assertTrue(mx.allclose(state, mx.ones((2,))))
if __name__ == "__main__":
unittest.main()

View File

@@ -15,18 +15,13 @@ class TestEval(mlx_tests.MLXTestCase):
self.assertEqual(x.tolist(), [[1, 1], [1, 1]])
def test_retain_graph(self):
def fun(x, retain_graph):
def fun(x):
y = 3 * x
mx.eval(y, retain_graph=retain_graph)
mx.eval(y)
return 2 * y
dfun_dx_1 = mx.grad(partial(fun, retain_graph=False))
dfun_dx_2 = mx.grad(partial(fun, retain_graph=True))
with self.assertRaises(ValueError):
dfun_dx_1(mx.array(1.0))
y = dfun_dx_2(mx.array(1.0))
dfun_dx = mx.grad(fun)
y = dfun_dx(mx.array(1.0))
self.assertEqual(y.item(), 6.0)