Compile now can attach arbitrary data to an entry (#2634)

This commit is contained in:
Angelos Katharopoulos
2025-09-30 13:33:27 -07:00
committed by GitHub
parent dc371ae7a5
commit eb24267b56
5 changed files with 130 additions and 31 deletions

View File

@@ -1064,6 +1064,57 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(mx.array(1.0), mx.array(2.0))
self.assertEqual(out.item(), 3.0)
def test_compile_changing_outputs(self):
@mx.compile
def fun(x, y):
if y is None:
return 2 * x
elif (
isinstance(x, mx.array)
and isinstance(y, mx.array)
and x.dtype == y.dtype == mx.float32
):
return [x + y]
elif y.dtype == mx.bool_:
return {"a": x, "b": y * x}
else:
return None
a = fun(mx.array(1.0), mx.array(2.0))
self.assertTrue(isinstance(a, list))
self.assertEqual(a[0].item(), 3.0)
b = fun(mx.array(1.0), mx.array(True))
self.assertTrue(isinstance(b, dict))
self.assertEqual(b["a"].item(), 1.0)
self.assertEqual(b["b"].item(), 1.0)
c = fun(mx.array(1.0), None)
self.assertTrue(isinstance(c, mx.array))
self.assertEqual(c.item(), 2.0)
d = fun(False, mx.array(1.0))
self.assertTrue(d is None)
def test_compile_changing_outputs_with_state(self):
state = [mx.array(1.0)]
@partial(mx.compile, inputs=state, outputs=state)
def fun(y):
x = state[0]
if y.dtype == mx.float32:
state[0] = 2 * y
return [x, y, x + y]
elif y.dtype == mx.int32:
state[0] *= 2
return x + y
for i in range(10):
fun(mx.array(1.0))
fun(mx.array(1))
self.assertEqual(state[0].item(), 4)
if __name__ == "__main__":
mlx_tests.MLXTestRunner()