Limit compile buffers (#1887)

* limit compile buffers

* maybe not flaky test
This commit is contained in:
Awni Hannun
2025-02-19 20:28:13 -08:00
committed by GitHub
parent 78ba24c37d
commit c707b2b0a6
4 changed files with 43 additions and 1 deletions

View File

@@ -815,6 +815,31 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(*inputs)
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
@mx.compile
def fun(arrs):
for _ in range(6):
arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]
return arrs[0]
arrs = [mx.array([1.0, 2.0]) for _ in range(64)]
out = fun(arrs)
self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))
def test_compile_many_outputs(self):
@mx.compile
def fun(arr):
arrs = [arr] * 64
first_arrs = None
for _ in range(6):
arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]
if first_arrs is None:
first_arrs = arrs
return arrs[0], first_arrs
out = fun(mx.array([1.0, 2.0]))
self.assertTrue(mx.allclose(out[0], mx.array([64.0, 128.0])))
def test_shapeless_compile_matmul(self):
a = mx.array([0.0, 1.0, 2.0])
b = mx.array([0.0, 1.0, 2.0])