fix for max block dim (#2631)

This commit is contained in:
Awni Hannun
2025-09-29 08:59:25 -07:00
committed by GitHub
parent e76a8dd5c5
commit dc371ae7a5
7 changed files with 67 additions and 21 deletions

View File

@@ -828,6 +828,19 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(arrs)
self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))
inputs = [mx.arange(16384).astype(mx.float16) for _ in range(8)]
def fun(inputs):
a = inputs[0] + inputs[1]
b = inputs[2] + inputs[3]
c = inputs[4] + inputs[5]
d = inputs[6] + inputs[7]
return a * b * c * d
out = mx.compile(fun)(inputs)
expected = fun(inputs)
self.assertTrue(mx.allclose(out, expected))
def test_compile_many_outputs(self):
@mx.compile