mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
More buffer donation in some cases (#1858)
* more donation * fix * add test
This commit is contained in:
@@ -174,6 +174,25 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
post = mx.metal.get_peak_memory()
|
||||
self.assertEqual(pre, post)
|
||||
|
||||
def test_donation_multiple_inputs(self):
|
||||
def fun(its, x, y):
|
||||
for _ in range(its):
|
||||
a = x + y # y should donate
|
||||
b = x + a # x should donate
|
||||
x, y = a, b
|
||||
return x, y
|
||||
|
||||
x = mx.zeros((128, 128))
|
||||
y = mx.zeros((128, 128))
|
||||
mx.metal.reset_peak_memory()
|
||||
a, b = fun(2, x, y)
|
||||
mx.eval(a, b)
|
||||
mem2 = mx.metal.get_peak_memory()
|
||||
a, b = fun(10, x, y)
|
||||
mx.eval(a, b)
|
||||
mem10 = mx.metal.get_peak_memory()
|
||||
self.assertEqual(mem2, mem10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user