Removes the retain_graph flag (#385)

* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
Angelos Katharopoulos
2024-01-07 15:16:51 -08:00
committed by GitHub
parent 449b43762e
commit a611b0bc82
22 changed files with 209 additions and 207 deletions

View File

@@ -15,18 +15,13 @@ class TestEval(mlx_tests.MLXTestCase):
self.assertEqual(x.tolist(), [[1, 1], [1, 1]])
def test_retain_graph(self):
def fun(x, retain_graph):
def fun(x):
y = 3 * x
mx.eval(y, retain_graph=retain_graph)
mx.eval(y)
return 2 * y
dfun_dx_1 = mx.grad(partial(fun, retain_graph=False))
dfun_dx_2 = mx.grad(partial(fun, retain_graph=True))
with self.assertRaises(ValueError):
dfun_dx_1(mx.array(1.0))
y = dfun_dx_2(mx.array(1.0))
dfun_dx = mx.grad(fun)
y = dfun_dx(mx.array(1.0))
self.assertEqual(y.item(), 6.0)