Fix leak with multi-output primitives (#1274)

* fix leak with multi-output primitives

* hopefully an actual fix
This commit is contained in:
Awni Hannun
2024-07-23 06:34:18 -07:00
committed by GitHub
parent df124e018a
commit 1fba87b0df
4 changed files with 29 additions and 5 deletions

View File

@@ -103,6 +103,23 @@ class TestEval(mlx_tests.MLXTestCase):
z = mx.add(y, x, stream=mx.cpu)
self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0)))
def test_multi_output_eval_during_transform(self):
x = mx.random.uniform(shape=(1024,))
y = mx.ones((1024,))
mx.eval(x, y)
def fn(x):
a, b = mx.divmod(x, x)
mx.eval(a)
return a
out = mx.vjp(fn, (x,), (y,))
out = mx.vjp(fn, (x,), (y,))
if mx.metal.is_available():
peak_mem = mx.metal.get_peak_memory()
out = mx.vjp(fn, (x,), (y,))
self.assertEqual(peak_mem, mx.metal.get_peak_memory())
if __name__ == "__main__":
unittest.main()