diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 5528b094e..65195808c 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1064,6 +1064,38 @@ class TestCompile(mlx_tests.MLXTestCase): out = fun(mx.array(1.0), mx.array(2.0)) self.assertEqual(out.item(), 3.0) + def test_compile_changing_outputs(self): + @mx.compile + def fun(x, y): + if y is None: + return 2 * x + elif ( + isinstance(x, mx.array) + and isinstance(y, mx.array) + and x.dtype == y.dtype == mx.float32 + ): + return [x + y] + elif y.dtype == mx.bool_: + return {"a": x, "b": y * x} + else: + return None + + a = fun(mx.array(1.0), mx.array(2.0)) + self.assertTrue(isinstance(a, list)) + self.assertEqual(a[0].item(), 3.0) + + b = fun(mx.array(1.0), mx.array(True)) + self.assertTrue(isinstance(b, dict)) + self.assertEqual(b["a"].item(), 1.0) + self.assertEqual(b["b"].item(), 1.0) + + c = fun(mx.array(1.0), None) + self.assertTrue(isinstance(c, mx.array)) + self.assertEqual(c.item(), 2.0) + + d = fun(False, mx.array(1.0)) + self.assertTrue(d is None) + if __name__ == "__main__": mlx_tests.MLXTestRunner()