mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 23:08:11 +08:00
fix for max block dim (#2631)
This commit is contained in:
@@ -828,6 +828,19 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
out = fun(arrs)
|
||||
self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))
|
||||
|
||||
inputs = [mx.arange(16384).astype(mx.float16) for _ in range(8)]
|
||||
|
||||
def fun(inputs):
|
||||
a = inputs[0] + inputs[1]
|
||||
b = inputs[2] + inputs[3]
|
||||
c = inputs[4] + inputs[5]
|
||||
d = inputs[6] + inputs[7]
|
||||
return a * b * c * d
|
||||
|
||||
out = mx.compile(fun)(inputs)
|
||||
expected = fun(inputs)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def test_compile_many_outputs(self):
|
||||
|
||||
@mx.compile
|
||||
|
Reference in New Issue
Block a user