More buffer donation with no-ops (#1591)

* more donation

* fix test

* fix build
This commit is contained in:
Awni Hannun
2024-11-18 08:35:41 -08:00
committed by GitHub
parent 6931f84412
commit 9bd03dd9b4
7 changed files with 82 additions and 13 deletions

View File

@@ -137,6 +137,43 @@ class TestEval(mlx_tests.MLXTestCase):
mx.async_eval(x)
mx.eval(a + b)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_donation_for_noops(self):
def fun(x):
s = x.shape
for _ in range(10):
x = mx.abs(x)
x = mx.reshape(x, (-1,))
x = x.T.T
x = mx.stop_gradient(x)
x = mx.abs(x)
return x
x = mx.zeros((4096, 4096))
mx.eval(x)
pre = mx.metal.get_peak_memory()
out = fun(x)
del x
mx.eval(out)
post = mx.metal.get_peak_memory()
self.assertEqual(pre, post)
def fun(x):
for _ in range(10):
x = mx.abs(x)
x = x[:-1]
x = mx.abs(x)
return x
x = mx.zeros((4096 * 4096,))
mx.eval(x)
pre = mx.metal.get_peak_memory()
out = fun(x)
del x
mx.eval(out)
post = mx.metal.get_peak_memory()
self.assertEqual(pre, post)
if __name__ == "__main__":
unittest.main()