Compile now can attach arbitrary data to an entry

This commit is contained in:
Angelos Katharopoulos
2025-09-29 23:21:50 -07:00
parent dc371ae7a5
commit 7cdb155d17
4 changed files with 59 additions and 27 deletions

View File

@@ -389,12 +389,6 @@ auto py_vmap(
};
}
std::unordered_map<std::uintptr_t, nb::object>& tree_cache() {
// This map is used to Cache the tree structure of the outputs
static std::unordered_map<std::uintptr_t, nb::object> tree_cache_;
return tree_cache_;
}
struct PyCompiledFun {
nb::callable fun;
std::uintptr_t fun_id;
@@ -508,7 +502,10 @@ struct PyCompiledFun {
auto [outputs, py_outputs] =
tree_flatten_with_structure(std::move(tree_outputs), false);
tree_cache().insert({fun_id, py_outputs});
std::shared_ptr<void> py_outputs_void(
py_outputs.release().ptr(), [](void* handle) {
nb::steal(reinterpret_cast<PyObject*>(handle)).reset();
});
num_outputs = outputs.size();
if (!captured_outputs.is_none()) {
@@ -523,7 +520,7 @@ struct PyCompiledFun {
if (!captured_inputs.is_none()) {
tree_replace(captured_inputs, trace_captures, flat_in_captures);
}
return outputs;
return mx::detail::ArraysAndExtra{outputs, py_outputs_void};
};
if (!captured_inputs.is_none()) {
@@ -535,7 +532,7 @@ struct PyCompiledFun {
}
// Compile and call
auto outputs =
auto [outputs, py_outputs_void] =
mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
if (!captured_outputs.is_none()) {
std::vector<mx::array> captures(
@@ -545,7 +542,8 @@ struct PyCompiledFun {
}
// Put the outputs back in the container
nb::object py_outputs = tree_cache().at(fun_id);
nb::object py_outputs =
nb::borrow(reinterpret_cast<PyObject*>(py_outputs_void.get()));
return tree_unflatten_from_structure(py_outputs, outputs);
}
@@ -556,7 +554,6 @@ struct PyCompiledFun {
~PyCompiledFun() {
nb::gil_scoped_acquire gil;
tree_cache().erase(fun_id);
mx::detail::compile_erase(fun_id);
fun.reset();
captured_inputs.reset();
@@ -1479,8 +1476,6 @@ void init_transforms(nb::module_& m) {
// Register static Python object cleanup before the interpreter exits
auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(nb::cpp_function([]() {
tree_cache().clear();
mx::detail::compile_clear_cache();
}));
atexit.attr("register")(
nb::cpp_function([]() { mx::detail::compile_clear_cache(); }));
}