More buffer donation in some cases (#1858)

* more donation

* fix

* add test
This commit is contained in:
Awni Hannun
2025-02-12 19:41:37 -08:00
committed by GitHub
parent 55c5ac7820
commit d274ae77f2
2 changed files with 40 additions and 4 deletions

View File

@@ -174,6 +174,25 @@ class TestEval(mlx_tests.MLXTestCase):
post = mx.metal.get_peak_memory()
self.assertEqual(pre, post)
def test_donation_multiple_inputs(self):
def fun(its, x, y):
for _ in range(its):
a = x + y # y should donate
b = x + a # x should donate
x, y = a, b
return x, y
x = mx.zeros((128, 128))
y = mx.zeros((128, 128))
mx.metal.reset_peak_memory()
a, b = fun(2, x, y)
mx.eval(a, b)
mem2 = mx.metal.get_peak_memory()
a, b = fun(10, x, y)
mx.eval(a, b)
mem10 = mx.metal.get_peak_memory()
self.assertEqual(mem2, mem10)
if __name__ == "__main__":
unittest.main()