Limit compile buffers (#1887)

* limit compile buffers

* maybe not flaky test
This commit is contained in:
Awni Hannun
2025-02-19 20:28:13 -08:00
committed by GitHub
parent 78ba24c37d
commit c707b2b0a6
4 changed files with 43 additions and 1 deletions

View File

@@ -177,6 +177,7 @@ class TestDistributed(mlx_tests.MLXTestCase):
def test_donation(self):
x = mx.random.normal((1024,))
mx.eval(x)
mx.synchronize(mx.default_stream(mx.default_device()))
mx.metal.reset_peak_memory()
scale = mx.array(2.0)

View File

@@ -815,6 +815,31 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(*inputs)
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
@mx.compile
def fun(arrs):
for _ in range(6):
arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]
return arrs[0]
arrs = [mx.array([1.0, 2.0]) for _ in range(64)]
out = fun(arrs)
self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))
def test_compile_many_outputs(self):
@mx.compile
def fun(arr):
arrs = [arr] * 64
first_arrs = None
for _ in range(6):
arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]
if first_arrs is None:
first_arrs = arrs
return arrs[0], first_arrs
out = fun(mx.array([1.0, 2.0]))
self.assertTrue(mx.allclose(out[0], mx.array([64.0, 128.0])))
def test_shapeless_compile_matmul(self):
a = mx.array([0.0, 1.0, 2.0])
b = mx.array([0.0, 1.0, 2.0])

View File

@@ -385,6 +385,7 @@ class TestLoad(mlx_tests.MLXTestCase):
mx.eval(x)
save_file = os.path.join(self.test_dir, "donation.npy")
mx.save(save_file, x)
mx.synchronize(mx.default_stream(mx.default_device()))
mx.metal.reset_peak_memory()
scale = mx.array(2.0)