mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-09 07:53:55 +08:00
Removes the retain_graph
flag (#385)
* Adds global tracing flag * Removes retain_graph in favor of is_tracer
This commit is contained in:

committed by
GitHub

parent
449b43762e
commit
a611b0bc82
@@ -259,6 +259,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in)))
|
||||
|
||||
def test_update_state(self):
|
||||
y = mx.array([1.0])
|
||||
state = mx.zeros((2,))
|
||||
|
||||
def fn(y, x):
|
||||
nonlocal state
|
||||
x = y * x
|
||||
state = state + x
|
||||
return x.sum()
|
||||
|
||||
x = mx.ones((2,))
|
||||
mx.grad(fn)(y, x)
|
||||
mx.eval(state)
|
||||
self.assertTrue(mx.allclose(state, mx.ones((2,))))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user