Fix for captured state

This commit is contained in:
Angelos Katharopoulos
2025-09-30 01:10:51 -07:00
parent cab27b7f0d
commit 3a7ad1b65b
2 changed files with 40 additions and 12 deletions

View File

@@ -1096,6 +1096,25 @@ class TestCompile(mlx_tests.MLXTestCase):
d = fun(False, mx.array(1.0))
self.assertTrue(d is None)
def test_compile_changing_outputs_with_state(self):
state = [mx.array(1.0)]
@partial(mx.compile, inputs=state, outputs=state)
def fun(y):
x = state[0]
if y.dtype == mx.float32:
state[0] = 2 * y
return [x, y, x + y]
elif y.dtype == mx.int32:
state[0] *= 2
return x + y
for i in range(10):
fun(mx.array(1.0))
fun(mx.array(1))
self.assertEqual(state[0].item(), 4)
if __name__ == "__main__":
mlx_tests.MLXTestRunner()