mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-07 17:44:38 +08:00
More buffer donation with no-ops (#1591)
* more donation * fix test * fix build
This commit is contained in:
@@ -137,6 +137,43 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
mx.async_eval(x)
|
||||
mx.eval(a + b)
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_donation_for_noops(self):
|
||||
def fun(x):
|
||||
s = x.shape
|
||||
for _ in range(10):
|
||||
x = mx.abs(x)
|
||||
x = mx.reshape(x, (-1,))
|
||||
x = x.T.T
|
||||
x = mx.stop_gradient(x)
|
||||
x = mx.abs(x)
|
||||
return x
|
||||
|
||||
x = mx.zeros((4096, 4096))
|
||||
mx.eval(x)
|
||||
pre = mx.metal.get_peak_memory()
|
||||
out = fun(x)
|
||||
del x
|
||||
mx.eval(out)
|
||||
post = mx.metal.get_peak_memory()
|
||||
self.assertEqual(pre, post)
|
||||
|
||||
def fun(x):
|
||||
for _ in range(10):
|
||||
x = mx.abs(x)
|
||||
x = x[:-1]
|
||||
x = mx.abs(x)
|
||||
return x
|
||||
|
||||
x = mx.zeros((4096 * 4096,))
|
||||
mx.eval(x)
|
||||
pre = mx.metal.get_peak_memory()
|
||||
out = fun(x)
|
||||
del x
|
||||
mx.eval(out)
|
||||
post = mx.metal.get_peak_memory()
|
||||
self.assertEqual(pre, post)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user