mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compile now can attach arbitrary data to an entry
This commit is contained in:
@@ -389,12 +389,6 @@ auto py_vmap(
|
||||
};
|
||||
}
|
||||
|
||||
// Accessor for a single shared map from std::uintptr_t keys to nb::object
// values. Returns a mutable reference, so callers insert into, look up in,
// erase from and clear the map directly through the returned reference.
// The call sites visible in this view use `fun_id` as the key.
std::unordered_map<std::uintptr_t, nb::object>& tree_cache() {
|
||||
// This map is used to cache the tree structure of the outputs.
|
||||
// Function-local static: every call returns a reference to the same map.
static std::unordered_map<std::uintptr_t, nb::object> tree_cache_;
|
||||
return tree_cache_;
|
||||
}
|
||||
|
||||
struct PyCompiledFun {
|
||||
nb::callable fun;
|
||||
std::uintptr_t fun_id;
|
||||
@@ -508,7 +502,10 @@ struct PyCompiledFun {
|
||||
auto [outputs, py_outputs] =
|
||||
tree_flatten_with_structure(std::move(tree_outputs), false);
|
||||
|
||||
tree_cache().insert({fun_id, py_outputs});
|
||||
std::shared_ptr<void> py_outputs_void(
|
||||
py_outputs.release().ptr(), [](void* handle) {
|
||||
nb::steal(reinterpret_cast<PyObject*>(handle)).reset();
|
||||
});
|
||||
|
||||
num_outputs = outputs.size();
|
||||
if (!captured_outputs.is_none()) {
|
||||
@@ -523,7 +520,7 @@ struct PyCompiledFun {
|
||||
if (!captured_inputs.is_none()) {
|
||||
tree_replace(captured_inputs, trace_captures, flat_in_captures);
|
||||
}
|
||||
return outputs;
|
||||
return mx::detail::ArraysAndExtra{outputs, py_outputs_void};
|
||||
};
|
||||
|
||||
if (!captured_inputs.is_none()) {
|
||||
@@ -535,7 +532,7 @@ struct PyCompiledFun {
|
||||
}
|
||||
|
||||
// Compile and call
|
||||
auto outputs =
|
||||
auto [outputs, py_outputs_void] =
|
||||
mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
||||
if (!captured_outputs.is_none()) {
|
||||
std::vector<mx::array> captures(
|
||||
@@ -545,7 +542,8 @@ struct PyCompiledFun {
|
||||
}
|
||||
|
||||
// Put the outputs back in the container
|
||||
nb::object py_outputs = tree_cache().at(fun_id);
|
||||
nb::object py_outputs =
|
||||
nb::borrow(reinterpret_cast<PyObject*>(py_outputs_void.get()));
|
||||
return tree_unflatten_from_structure(py_outputs, outputs);
|
||||
}
|
||||
|
||||
@@ -556,7 +554,6 @@ struct PyCompiledFun {
|
||||
~PyCompiledFun() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
tree_cache().erase(fun_id);
|
||||
mx::detail::compile_erase(fun_id);
|
||||
fun.reset();
|
||||
captured_inputs.reset();
|
||||
@@ -1479,8 +1476,6 @@ void init_transforms(nb::module_& m) {
|
||||
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
auto atexit = nb::module_::import_("atexit");
|
||||
atexit.attr("register")(nb::cpp_function([]() {
|
||||
tree_cache().clear();
|
||||
mx::detail::compile_clear_cache();
|
||||
}));
|
||||
atexit.attr("register")(
|
||||
nb::cpp_function([]() { mx::detail::compile_clear_cache(); }));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user