From cab27b7f0d485ce6d2db20c2f3b6518ed17f8c0f Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 29 Sep 2025 23:47:52 -0700 Subject: [PATCH] Add a test --- python/tests/test_compile.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 5528b094e..65195808c 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1064,6 +1064,38 @@ class TestCompile(mlx_tests.MLXTestCase): out = fun(mx.array(1.0), mx.array(2.0)) self.assertEqual(out.item(), 3.0) + def test_compile_changing_outputs(self): + @mx.compile + def fun(x, y): + if y is None: + return 2 * x + elif ( + isinstance(x, mx.array) + and isinstance(y, mx.array) + and x.dtype == y.dtype == mx.float32 + ): + return [x + y] + elif y.dtype == mx.bool_: + return {"a": x, "b": y * x} + else: + return None + + a = fun(mx.array(1.0), mx.array(2.0)) + self.assertTrue(isinstance(a, list)) + self.assertEqual(a[0].item(), 3.0) + + b = fun(mx.array(1.0), mx.array(True)) + self.assertTrue(isinstance(b, dict)) + self.assertEqual(b["a"].item(), 1.0) + self.assertEqual(b["b"].item(), 1.0) + + c = fun(mx.array(1.0), None) + self.assertTrue(isinstance(c, mx.array)) + self.assertEqual(c.item(), 2.0) + + d = fun(False, mx.array(1.0)) + self.assertTrue(d is None) + if __name__ == "__main__": mlx_tests.MLXTestRunner()