mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add a test
This commit is contained in:
@@ -1064,6 +1064,38 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
out = fun(mx.array(1.0), mx.array(2.0))
|
out = fun(mx.array(1.0), mx.array(2.0))
|
||||||
self.assertEqual(out.item(), 3.0)
|
self.assertEqual(out.item(), 3.0)
|
||||||
|
|
||||||
|
def test_compile_changing_outputs(self):
|
||||||
|
@mx.compile
|
||||||
|
def fun(x, y):
|
||||||
|
if y is None:
|
||||||
|
return 2 * x
|
||||||
|
elif (
|
||||||
|
isinstance(x, mx.array)
|
||||||
|
and isinstance(y, mx.array)
|
||||||
|
and x.dtype == y.dtype == mx.float32
|
||||||
|
):
|
||||||
|
return [x + y]
|
||||||
|
elif y.dtype == mx.bool_:
|
||||||
|
return {"a": x, "b": y * x}
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
a = fun(mx.array(1.0), mx.array(2.0))
|
||||||
|
self.assertTrue(isinstance(a, list))
|
||||||
|
self.assertEqual(a[0].item(), 3.0)
|
||||||
|
|
||||||
|
b = fun(mx.array(1.0), mx.array(True))
|
||||||
|
self.assertTrue(isinstance(b, dict))
|
||||||
|
self.assertEqual(b["a"].item(), 1.0)
|
||||||
|
self.assertEqual(b["b"].item(), 1.0)
|
||||||
|
|
||||||
|
c = fun(mx.array(1.0), None)
|
||||||
|
self.assertTrue(isinstance(c, mx.array))
|
||||||
|
self.assertEqual(c.item(), 2.0)
|
||||||
|
|
||||||
|
d = fun(False, mx.array(1.0))
|
||||||
|
self.assertTrue(d is None)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
|||||||
Reference in New Issue
Block a user