mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 23:08:11 +08:00
Compile now can attach arbitrary data to an entry (#2634)
This commit is contained in:

committed by
GitHub

parent
dc371ae7a5
commit
eb24267b56
@@ -1064,6 +1064,57 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
out = fun(mx.array(1.0), mx.array(2.0))
|
||||
self.assertEqual(out.item(), 3.0)
|
||||
|
||||
def test_compile_changing_outputs(self):
|
||||
@mx.compile
|
||||
def fun(x, y):
|
||||
if y is None:
|
||||
return 2 * x
|
||||
elif (
|
||||
isinstance(x, mx.array)
|
||||
and isinstance(y, mx.array)
|
||||
and x.dtype == y.dtype == mx.float32
|
||||
):
|
||||
return [x + y]
|
||||
elif y.dtype == mx.bool_:
|
||||
return {"a": x, "b": y * x}
|
||||
else:
|
||||
return None
|
||||
|
||||
a = fun(mx.array(1.0), mx.array(2.0))
|
||||
self.assertTrue(isinstance(a, list))
|
||||
self.assertEqual(a[0].item(), 3.0)
|
||||
|
||||
b = fun(mx.array(1.0), mx.array(True))
|
||||
self.assertTrue(isinstance(b, dict))
|
||||
self.assertEqual(b["a"].item(), 1.0)
|
||||
self.assertEqual(b["b"].item(), 1.0)
|
||||
|
||||
c = fun(mx.array(1.0), None)
|
||||
self.assertTrue(isinstance(c, mx.array))
|
||||
self.assertEqual(c.item(), 2.0)
|
||||
|
||||
d = fun(False, mx.array(1.0))
|
||||
self.assertTrue(d is None)
|
||||
|
||||
def test_compile_changing_outputs_with_state(self):
|
||||
state = [mx.array(1.0)]
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def fun(y):
|
||||
x = state[0]
|
||||
if y.dtype == mx.float32:
|
||||
state[0] = 2 * y
|
||||
return [x, y, x + y]
|
||||
elif y.dtype == mx.int32:
|
||||
state[0] *= 2
|
||||
return x + y
|
||||
|
||||
for i in range(10):
|
||||
fun(mx.array(1.0))
|
||||
fun(mx.array(1))
|
||||
|
||||
self.assertEqual(state[0].item(), 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
Reference in New Issue
Block a user