From 3a7ad1b65b20fd034270b409f418e9e4a8763cac Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 30 Sep 2025 01:10:51 -0700 Subject: [PATCH] Fix for captured state --- python/src/transforms.cpp | 33 +++++++++++++++++++++------------ python/tests/test_compile.py | 19 +++++++++++++++++++ 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index dbbc3fd54..12aa641f8 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -395,7 +395,16 @@ struct PyCompiledFun { nb::object captured_inputs; nb::object captured_outputs; bool shapeless; - mutable size_t num_outputs{0}; + + // Data to attach to the compiled function that contains the python output + // structure and the number of arrays in said structure. + struct AttachedData { + nb::object output_structure; + int num_outputs; + + AttachedData(nb::object output_structure_, int num_outputs_) + : output_structure(output_structure_), num_outputs(num_outputs_) {} + }; PyCompiledFun( const nb::callable& fun, @@ -418,7 +427,6 @@ struct PyCompiledFun { captured_inputs = std::move(other.captured_inputs); captured_outputs = std::move(other.captured_outputs); shapeless = other.shapeless; - num_outputs = other.num_outputs; }; nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) { @@ -502,12 +510,9 @@ struct PyCompiledFun { auto [outputs, py_outputs] = tree_flatten_with_structure(std::move(tree_outputs), false); - std::shared_ptr py_outputs_void( - py_outputs.release().ptr(), [](void* handle) { - nb::steal(reinterpret_cast(handle)).reset(); - }); + std::shared_ptr extra_data = + std::make_shared(py_outputs, outputs.size()); - num_outputs = outputs.size(); if (!captured_outputs.is_none()) { auto flat_out_captures = tree_flatten(captured_outputs, false); outputs.insert( @@ -520,7 +525,7 @@ struct PyCompiledFun { if (!captured_inputs.is_none()) { tree_replace(captured_inputs, trace_captures, flat_in_captures); } - return mx::detail::ArraysAndExtra{outputs, py_outputs_void}; + return mx::detail::ArraysAndExtra{outputs, extra_data}; }; if (!captured_inputs.is_none()) { @@ -532,8 +537,14 @@ struct PyCompiledFun { } // Compile and call - auto [outputs, py_outputs_void] = + auto [outputs, extra_data] = mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs); + + int num_outputs = + reinterpret_cast(extra_data.get())->num_outputs; + nb::object py_outputs = + reinterpret_cast(extra_data.get())->output_structure; + if (!captured_outputs.is_none()) { std::vector captures( std::make_move_iterator(outputs.begin() + num_outputs), @@ -542,9 +553,7 @@ struct PyCompiledFun { } // Put the outputs back in the container - nb::object py_outputs = - nb::borrow(reinterpret_cast(py_outputs_void.get())); - return tree_unflatten_from_structure(py_outputs, outputs); + return tree_unflatten_from_structure(std::move(py_outputs), outputs); } nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const { diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 65195808c..572123c60 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1096,6 +1096,25 @@ class TestCompile(mlx_tests.MLXTestCase): d = fun(False, mx.array(1.0)) self.assertTrue(d is None) + def test_compile_changing_outputs_with_state(self): + state = [mx.array(1.0)] + + @partial(mx.compile, inputs=state, outputs=state) + def fun(y): + x = state[0] + if y.dtype == mx.float32: + state[0] = 2 * y + return [x, y, x + y] + elif y.dtype == mx.int32: + state[0] *= 2 + return x + y + + for i in range(10): + fun(mx.array(1.0)) + fun(mx.array(1)) + + self.assertEqual(state[0].item(), 4) + if __name__ == "__main__": mlx_tests.MLXTestRunner()