Add a test

This commit is contained in:
Angelos Katharopoulos
2025-09-29 23:47:52 -07:00
parent 6faf593775
commit cab27b7f0d

View File

@@ -1064,6 +1064,38 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(mx.array(1.0), mx.array(2.0))
self.assertEqual(out.item(), 3.0)
def test_compile_changing_outputs(self):
@mx.compile
def fun(x, y):
if y is None:
return 2 * x
elif (
isinstance(x, mx.array)
and isinstance(y, mx.array)
and x.dtype == y.dtype == mx.float32
):
return [x + y]
elif y.dtype == mx.bool_:
return {"a": x, "b": y * x}
else:
return None
a = fun(mx.array(1.0), mx.array(2.0))
self.assertTrue(isinstance(a, list))
self.assertEqual(a[0].item(), 3.0)
b = fun(mx.array(1.0), mx.array(True))
self.assertTrue(isinstance(b, dict))
self.assertEqual(b["a"].item(), 1.0)
self.assertEqual(b["b"].item(), 1.0)
c = fun(mx.array(1.0), None)
self.assertTrue(isinstance(c, mx.array))
self.assertEqual(c.item(), 2.0)
d = fun(False, mx.array(1.0))
self.assertTrue(d is None)
if __name__ == "__main__":
mlx_tests.MLXTestRunner()