Fix compile (#1501)

* fix compile

* fix space
This commit is contained in:
Awni Hannun
2024-10-18 11:06:40 -07:00
committed by GitHub
parent 50d8bed468
commit 92d7cb71f8
5 changed files with 36 additions and 14 deletions

View File

@@ -772,6 +772,23 @@ class TestCompile(mlx_tests.MLXTestCase):
print((out - expected).abs().max())
self.assertTrue(mx.allclose(out, expected))
def test_compile_many_inputs(self):
inputs = [mx.ones((2, 2, 2, 2)) for _ in range(20)]
inputs[0] = inputs[0].T
@mx.compile
def fun(*inputs):
x = inputs[0]
for y in inputs[1:10]:
x = x + y
a = inputs[10]
for b in inputs[11:]:
a = a + b
return x + a
out = fun(*inputs)
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
if __name__ == "__main__":
unittest.main()