From eb24267b568e8ee1ac46571b642eb7b9ba8ee69f Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 30 Sep 2025 13:33:27 -0700 Subject: [PATCH] Compile now can attach arbitrary data to an entry (#2634) --- mlx/compile.cpp | 49 ++++++++++++++++++++++++++++------ mlx/compile_impl.h | 15 +++++++++-- mlx/export.cpp | 4 +-- python/src/transforms.cpp | 42 +++++++++++++++-------------- python/tests/test_compile.py | 51 ++++++++++++++++++++++++++++++++++++ 5 files changed, 130 insertions(+), 31 deletions(-) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 90eeaa95a..e6cd58755 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -296,6 +296,7 @@ class CompilerCache { std::vector tape; bool empty{true}; std::vector constants; + std::shared_ptr extra; }; // Returns a reference to a CacheEntry which can be updated @@ -376,8 +377,9 @@ CompilerCache& compiler_cache() { return compiler_cache_; } -std::pair, std::vector> compile_trace( - const std::function(const std::vector&)>& fun, +std::tuple, std::vector, std::shared_ptr> +compile_trace( + const ArrayFnWithExtra& fun, const std::vector& inputs, bool shapeless) { // Set the global tracing flag. @@ -391,7 +393,9 @@ std::pair, std::vector> compile_trace( in.set_tracer(true); tracer_inputs.push_back(std::move(in)); } - return {tracer_inputs, fun(tracer_inputs)}; + + auto output = fun(tracer_inputs); + return {tracer_inputs, output.first, output.second}; } // Traverses the graph to build a tape and a map of array ids to their parents @@ -932,8 +936,8 @@ bool skip_compile() { !(compile_available_for_device(default_device())); } -std::function(const std::vector&)> compile( - std::function(const std::vector&)> fun, +ArrayFnWithExtra compile( + ArrayFnWithExtra fun, std::uintptr_t fun_id, bool shapeless /* = false */, std::vector constants /* = {} */) { @@ -966,7 +970,7 @@ std::function(const std::vector&)> compile( // Set the constants entry.constants = std::move(constants); // Trace to build the graph - std::tie(entry.inputs, entry.outputs) = + std::tie(entry.inputs, entry.outputs, entry.extra) = compile_trace(fun, inputs, shapeless); // DFS the graph and get a tape, and a map of array id to (parent, @@ -991,8 +995,37 @@ std::function(const std::vector&)> compile( // At this point we must have a tape, now replace the placeholders // with real arrays that can be evaluated - return compile_replace( - entry.tape, entry.inputs, entry.outputs, inputs, shapeless); + return ArraysAndExtra{ + compile_replace( + entry.tape, entry.inputs, entry.outputs, inputs, shapeless), + entry.extra}; + }; +} + +std::function(const std::vector&)> compile( + std::function(const std::vector&)> fun, + std::uintptr_t fun_id, + bool shapeless /* = false */, + std::vector constants /* = {} */) { + if (skip_compile()) { + return fun; + } + if (!fun) { + throw std::invalid_argument( + "[compile] Cannot compile a function without a target."); + } + + ArrayFnWithExtra fun_with_extra = + [fun = std::move(fun)](const std::vector& inputs) { + return ArraysAndExtra{fun(inputs), nullptr}; + }; + + auto compiled_fun = compile( + std::move(fun_with_extra), fun_id, shapeless, std::move(constants)); + + return [compiled_fun = + std::move(compiled_fun)](const std::vector& inputs) { + return compiled_fun(inputs).first; }; } diff --git a/mlx/compile_impl.h b/mlx/compile_impl.h index bc3f44b4e..1b48c88aa 100644 --- a/mlx/compile_impl.h +++ b/mlx/compile_impl.h @@ -8,6 +8,10 @@ namespace mlx::core::detail { +using ArraysAndExtra = std::pair, std::shared_ptr>; +using ArrayFnWithExtra = + std::function&)>; + // This is not part of the general C++ API as calling with a bad id is a bad // idea. std::function(const std::vector&)> compile( @@ -16,6 +20,12 @@ std::function(const std::vector&)> compile( bool shapeless = false, std::vector constants = {}); +ArrayFnWithExtra compile( + ArrayFnWithExtra fun, + std::uintptr_t fun_id, + bool shapeless, + std::vector constants); + // Erase cached compile functions void compile_erase(std::uintptr_t fun_id); @@ -25,8 +35,9 @@ void compile_clear_cache(); bool compile_available_for_device(const Device& device); -std::pair, std::vector> compile_trace( - const std::function(const std::vector&)>& fun, +std::tuple, std::vector, std::shared_ptr> +compile_trace( + const ArrayFnWithExtra& fun, const std::vector& inputs, bool shapeless); diff --git a/mlx/export.cpp b/mlx/export.cpp index 19944dfc4..ca7e316b3 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -579,11 +579,11 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { for (auto& k : kwarg_keys) { kwargs.insert({k, *it++}); } - return fun(args, kwargs); + return detail::ArraysAndExtra{fun(args, kwargs), nullptr}; }; // Trace to build the graph - auto [trace_inputs, trace_outputs] = + auto [trace_inputs, trace_outputs, extra] = detail::compile_trace(flat_fun, inputs, ftable->shapeless); // DFS the graph and get the tape diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 8c8c16316..12aa641f8 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -389,19 +389,22 @@ auto py_vmap( }; } -std::unordered_map& tree_cache() { - // This map is used to Cache the tree structure of the outputs - static std::unordered_map tree_cache_; - return tree_cache_; -} - struct PyCompiledFun { nb::callable fun; std::uintptr_t fun_id; nb::object captured_inputs; nb::object captured_outputs; bool shapeless; - mutable size_t num_outputs{0}; + + // Data to attach to the compiled function that contains the python output + // structure and the number of arrays in said structure. + struct AttachedData { + nb::object output_structure; + int num_outputs; + + AttachedData(nb::object output_structure_, int num_outputs_) + : output_structure(output_structure_), num_outputs(num_outputs_) {} + }; PyCompiledFun( const nb::callable& fun, @@ -424,7 +427,6 @@ struct PyCompiledFun { captured_inputs = std::move(other.captured_inputs); captured_outputs = std::move(other.captured_outputs); shapeless = other.shapeless; - num_outputs = other.num_outputs; }; nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) { @@ -508,9 +510,9 @@ struct PyCompiledFun { auto [outputs, py_outputs] = tree_flatten_with_structure(std::move(tree_outputs), false); - tree_cache().insert({fun_id, py_outputs}); + std::shared_ptr extra_data = + std::make_shared(py_outputs, outputs.size()); - num_outputs = outputs.size(); if (!captured_outputs.is_none()) { auto flat_out_captures = tree_flatten(captured_outputs, false); outputs.insert( @@ -523,7 +525,7 @@ struct PyCompiledFun { if (!captured_inputs.is_none()) { tree_replace(captured_inputs, trace_captures, flat_in_captures); } - return outputs; + return mx::detail::ArraysAndExtra{outputs, extra_data}; }; if (!captured_inputs.is_none()) { @@ -535,8 +537,14 @@ struct PyCompiledFun { } // Compile and call - auto outputs = + auto [outputs, extra_data] = mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs); + + int num_outputs = + reinterpret_cast(extra_data.get())->num_outputs; + nb::object py_outputs = + reinterpret_cast(extra_data.get())->output_structure; + if (!captured_outputs.is_none()) { std::vector captures( std::make_move_iterator(outputs.begin() + num_outputs), @@ -545,8 +553,7 @@ struct PyCompiledFun { } // Put the outputs back in the container - nb::object py_outputs = tree_cache().at(fun_id); - return tree_unflatten_from_structure(py_outputs, outputs); + return tree_unflatten_from_structure(std::move(py_outputs), outputs); } nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const { @@ -556,7 +563,6 @@ struct PyCompiledFun { ~PyCompiledFun() { nb::gil_scoped_acquire gil; - tree_cache().erase(fun_id); mx::detail::compile_erase(fun_id); fun.reset(); captured_inputs.reset(); @@ -1479,8 +1485,6 @@ void init_transforms(nb::module_& m) { // Register static Python object cleanup before the interpreter exits auto atexit = nb::module_::import_("atexit"); - atexit.attr("register")(nb::cpp_function([]() { - tree_cache().clear(); - mx::detail::compile_clear_cache(); - })); + atexit.attr("register")( + nb::cpp_function([]() { mx::detail::compile_clear_cache(); })); } diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 5528b094e..572123c60 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1064,6 +1064,57 @@ class TestCompile(mlx_tests.MLXTestCase): out = fun(mx.array(1.0), mx.array(2.0)) self.assertEqual(out.item(), 3.0) + def test_compile_changing_outputs(self): + @mx.compile + def fun(x, y): + if y is None: + return 2 * x + elif ( + isinstance(x, mx.array) + and isinstance(y, mx.array) + and x.dtype == y.dtype == mx.float32 + ): + return [x + y] + elif y.dtype == mx.bool_: + return {"a": x, "b": y * x} + else: + return None + + a = fun(mx.array(1.0), mx.array(2.0)) + self.assertTrue(isinstance(a, list)) + self.assertEqual(a[0].item(), 3.0) + + b = fun(mx.array(1.0), mx.array(True)) + self.assertTrue(isinstance(b, dict)) + self.assertEqual(b["a"].item(), 1.0) + self.assertEqual(b["b"].item(), 1.0) + + c = fun(mx.array(1.0), None) + self.assertTrue(isinstance(c, mx.array)) + self.assertEqual(c.item(), 2.0) + + d = fun(False, mx.array(1.0)) + self.assertTrue(d is None) + + def test_compile_changing_outputs_with_state(self): + state = [mx.array(1.0)] + + @partial(mx.compile, inputs=state, outputs=state) + def fun(y): + x = state[0] + if y.dtype == mx.float32: + state[0] = 2 * y + return [x, y, x + y] + elif y.dtype == mx.int32: + state[0] *= 2 + return x + y + + for i in range(10): + fun(mx.array(1.0)) + fun(mx.array(1)) + + self.assertEqual(state[0].item(), 4) + if __name__ == "__main__": mlx_tests.MLXTestRunner()