mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Fix leak with multi-output primitives (#1274)
* fix leak with multi-output primitives * hopefully an actual fix
This commit is contained in:
@@ -103,6 +103,23 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
z = mx.add(y, x, stream=mx.cpu)
|
||||
self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0)))
|
||||
|
||||
def test_multi_output_eval_during_transform(self):
|
||||
x = mx.random.uniform(shape=(1024,))
|
||||
y = mx.ones((1024,))
|
||||
mx.eval(x, y)
|
||||
|
||||
def fn(x):
|
||||
a, b = mx.divmod(x, x)
|
||||
mx.eval(a)
|
||||
return a
|
||||
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
if mx.metal.is_available():
|
||||
peak_mem = mx.metal.get_peak_memory()
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
self.assertEqual(peak_mem, mx.metal.get_peak_memory())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user