fix donation in scan (#1917)

This commit is contained in:
Awni Hannun
2025-03-03 11:30:59 -08:00
committed by GitHub
parent ba12e4999a
commit 6bcd6bcf70
2 changed files with 19 additions and 6 deletions

View File

@@ -1894,6 +1894,22 @@ class TestOps(mlx_tests.MLXTestCase):
expected = mx.repeat(expected[:, None], 2, axis=1)
self.assertTrue(mx.array_equal(expected, out))
# Test donation
def fn(its):
x = mx.ones((32,))
for _ in range(its):
x = mx.cumsum(x)
return x
mx.synchronize(mx.default_stream(mx.default_device()))
mx.eval(fn(2))
mx.synchronize(mx.default_stream(mx.default_device()))
mem2 = mx.metal.get_peak_memory()
mx.eval(fn(4))
mx.synchronize(mx.default_stream(mx.default_device()))
mem4 = mx.metal.get_peak_memory()
self.assertEqual(mem2, mem4)
def test_squeeze_expand(self):
a = mx.zeros((2, 1, 2, 1))
self.assertEqual(mx.squeeze(a).shape, (2, 2))