mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:29:35 +08:00
Removes the retain_graph
flag (#385)
* Adds global tracing flag * Removes retain_graph in favor of is_tracer
This commit is contained in:

committed by
GitHub

parent
449b43762e
commit
a611b0bc82
@@ -15,18 +15,13 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(x.tolist(), [[1, 1], [1, 1]])
|
||||
|
||||
def test_retain_graph(self):
|
||||
def fun(x, retain_graph):
|
||||
def fun(x):
|
||||
y = 3 * x
|
||||
mx.eval(y, retain_graph=retain_graph)
|
||||
mx.eval(y)
|
||||
return 2 * y
|
||||
|
||||
dfun_dx_1 = mx.grad(partial(fun, retain_graph=False))
|
||||
dfun_dx_2 = mx.grad(partial(fun, retain_graph=True))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
dfun_dx_1(mx.array(1.0))
|
||||
|
||||
y = dfun_dx_2(mx.array(1.0))
|
||||
dfun_dx = mx.grad(fun)
|
||||
y = dfun_dx(mx.array(1.0))
|
||||
self.assertEqual(y.item(), 6.0)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user