mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix for captured state
This commit is contained in:
@@ -1096,6 +1096,25 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
d = fun(False, mx.array(1.0))
|
||||
self.assertTrue(d is None)
|
||||
|
||||
def test_compile_changing_outputs_with_state(self):
|
||||
state = [mx.array(1.0)]
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def fun(y):
|
||||
x = state[0]
|
||||
if y.dtype == mx.float32:
|
||||
state[0] = 2 * y
|
||||
return [x, y, x + y]
|
||||
elif y.dtype == mx.int32:
|
||||
state[0] *= 2
|
||||
return x + y
|
||||
|
||||
for i in range(10):
|
||||
fun(mx.array(1.0))
|
||||
fun(mx.array(1))
|
||||
|
||||
self.assertEqual(state[0].item(), 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
Reference in New Issue
Block a user