From 7cdb155d17c6033f4ce211be9be129f47af4f477 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 29 Sep 2025 23:21:50 -0700 Subject: [PATCH] Compile now can attach arbitrary data to an entry --- mlx/compile.cpp | 42 +++++++++++++++++++++++++++++++-------- mlx/compile_impl.h | 15 ++++++++++++-- mlx/export.cpp | 4 ++-- python/src/transforms.cpp | 25 ++++++++++------------- 4 files changed, 59 insertions(+), 27 deletions(-) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 90eeaa95a..6e0be94a8 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -296,6 +296,7 @@ class CompilerCache { std::vector tape; bool empty{true}; std::vector constants; + std::shared_ptr extra; }; // Returns a reference to a CacheEntry which can be updated @@ -376,8 +377,9 @@ CompilerCache& compiler_cache() { return compiler_cache_; } -std::pair, std::vector> compile_trace( - const std::function(const std::vector&)>& fun, +std::tuple, std::vector, std::shared_ptr> +compile_trace( + const ArrayFnWithExtra& fun, const std::vector& inputs, bool shapeless) { // Set the global tracing flag. @@ -391,7 +393,9 @@ std::pair, std::vector> compile_trace( in.set_tracer(true); tracer_inputs.push_back(std::move(in)); } - return {tracer_inputs, fun(tracer_inputs)}; + + auto output = fun(tracer_inputs); + return {tracer_inputs, output.first, output.second}; } // Traverses the graph to build a tape and a map of array ids to their parents @@ -932,8 +936,8 @@ bool skip_compile() { !(compile_available_for_device(default_device())); } -std::function(const std::vector&)> compile( - std::function(const std::vector&)> fun, +ArrayFnWithExtra compile( + ArrayFnWithExtra fun, std::uintptr_t fun_id, bool shapeless /* = false */, std::vector constants /* = {} */) { @@ -966,7 +970,7 @@ std::function(const std::vector&)> compile( // Set the constants entry.constants = std::move(constants); // Trace to build the graph - std::tie(entry.inputs, entry.outputs) = + std::tie(entry.inputs, entry.outputs, entry.extra) = compile_trace(fun, inputs, shapeless); // DFS the graph and get a tape, and a map of array id to (parent, @@ -991,8 +995,30 @@ std::function(const std::vector&)> compile( // At this point we must have a tape, now replace the placeholders // with real arrays that can be evaluated - return compile_replace( - entry.tape, entry.inputs, entry.outputs, inputs, shapeless); + return std::pair, std::shared_ptr>{ + compile_replace( + entry.tape, entry.inputs, entry.outputs, inputs, shapeless), + entry.extra}; + }; +} + +std::function(const std::vector&)> compile( + std::function(const std::vector&)> fun, + std::uintptr_t fun_id, + bool shapeless /* = false */, + std::vector constants /* = {} */) { + ArrayFnWithExtra fun_with_extra = + [fun = std::move(fun)](const std::vector& inputs) { + return std::pair, std::shared_ptr>{ + fun(inputs), nullptr}; + }; + + auto compiled_fun = compile( + std::move(fun_with_extra), fun_id, shapeless, std::move(constants)); + + return [compiled_fun = + std::move(compiled_fun)](const std::vector& inputs) { + return compiled_fun(inputs).first; }; } diff --git a/mlx/compile_impl.h b/mlx/compile_impl.h index bc3f44b4e..1b48c88aa 100644 --- a/mlx/compile_impl.h +++ b/mlx/compile_impl.h @@ -8,6 +8,10 @@ namespace mlx::core::detail { +using ArraysAndExtra = std::pair, std::shared_ptr>; +using ArrayFnWithExtra = + std::function&)>; + // This is not part of the general C++ API as calling with a bad id is a bad // idea. std::function(const std::vector&)> compile( @@ -16,6 +20,12 @@ std::function(const std::vector&)> compile( bool shapeless = false, std::vector constants = {}); +ArrayFnWithExtra compile( + ArrayFnWithExtra fun, + std::uintptr_t fun_id, + bool shapeless, + std::vector constants); + // Erase cached compile functions void compile_erase(std::uintptr_t fun_id); @@ -25,8 +35,9 @@ void compile_clear_cache(); bool compile_available_for_device(const Device& device); -std::pair, std::vector> compile_trace( - const std::function(const std::vector&)>& fun, +std::tuple, std::vector, std::shared_ptr> +compile_trace( + const ArrayFnWithExtra& fun, const std::vector& inputs, bool shapeless); diff --git a/mlx/export.cpp b/mlx/export.cpp index 19944dfc4..ca7e316b3 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -579,11 +579,11 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { for (auto& k : kwarg_keys) { kwargs.insert({k, *it++}); } - return fun(args, kwargs); + return detail::ArraysAndExtra{fun(args, kwargs), nullptr}; }; // Trace to build the graph - auto [trace_inputs, trace_outputs] = + auto [trace_inputs, trace_outputs, extra] = detail::compile_trace(flat_fun, inputs, ftable->shapeless); // DFS the graph and get the tape diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 8c8c16316..dbbc3fd54 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -389,12 +389,6 @@ auto py_vmap( }; } -std::unordered_map& tree_cache() { - // This map is used to Cache the tree structure of the outputs - static std::unordered_map tree_cache_; - return tree_cache_; -} - struct PyCompiledFun { nb::callable fun; std::uintptr_t fun_id; @@ -508,7 +502,10 @@ struct PyCompiledFun { auto [outputs, py_outputs] = tree_flatten_with_structure(std::move(tree_outputs), false); - tree_cache().insert({fun_id, py_outputs}); + std::shared_ptr py_outputs_void( + py_outputs.release().ptr(), [](void* handle) { + nb::steal(reinterpret_cast(handle)).reset(); + }); num_outputs = outputs.size(); if (!captured_outputs.is_none()) { @@ -523,7 +520,7 @@ struct PyCompiledFun { if (!captured_inputs.is_none()) { tree_replace(captured_inputs, trace_captures, flat_in_captures); } - return outputs; + return mx::detail::ArraysAndExtra{outputs, py_outputs_void}; }; if (!captured_inputs.is_none()) { @@ -535,7 +532,7 @@ struct PyCompiledFun { } // Compile and call - auto outputs = + auto [outputs, py_outputs_void] = mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs); if (!captured_outputs.is_none()) { std::vector captures( @@ -545,7 +542,8 @@ struct PyCompiledFun { } // Put the outputs back in the container - nb::object py_outputs = tree_cache().at(fun_id); + nb::object py_outputs = + nb::borrow(reinterpret_cast(py_outputs_void.get())); return tree_unflatten_from_structure(py_outputs, outputs); } @@ -556,7 +554,6 @@ struct PyCompiledFun { ~PyCompiledFun() { nb::gil_scoped_acquire gil; - tree_cache().erase(fun_id); mx::detail::compile_erase(fun_id); fun.reset(); captured_inputs.reset(); @@ -1479,8 +1476,6 @@ void init_transforms(nb::module_& m) { // Register static Python object cleanup before the interpreter exits auto atexit = nb::module_::import_("atexit"); - atexit.attr("register")(nb::cpp_function([]() { - tree_cache().clear(); - mx::detail::compile_clear_cache(); - })); + atexit.attr("register")( + nb::cpp_function([]() { mx::detail::compile_clear_cache(); })); }