mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
@@ -772,6 +772,23 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
print((out - expected).abs().max())
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def test_compile_many_inputs(self):
|
||||
inputs = [mx.ones((2, 2, 2, 2)) for _ in range(20)]
|
||||
inputs[0] = inputs[0].T
|
||||
|
||||
@mx.compile
|
||||
def fun(*inputs):
|
||||
x = inputs[0]
|
||||
for y in inputs[1:10]:
|
||||
x = x + y
|
||||
a = inputs[10]
|
||||
for b in inputs[11:]:
|
||||
a = a + b
|
||||
return x + a
|
||||
|
||||
out = fun(*inputs)
|
||||
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user