mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compile now can attach arbitrary data to an entry
This commit is contained in:
@@ -296,6 +296,7 @@ class CompilerCache {
|
||||
std::vector<array> tape;
|
||||
bool empty{true};
|
||||
std::vector<uint64_t> constants;
|
||||
std::shared_ptr<void> extra;
|
||||
};
|
||||
|
||||
// Returns a reference to a CacheEntry which can be updated
|
||||
@@ -376,8 +377,9 @@ CompilerCache& compiler_cache() {
|
||||
return compiler_cache_;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
|
||||
compile_trace(
|
||||
const ArrayFnWithExtra& fun,
|
||||
const std::vector<array>& inputs,
|
||||
bool shapeless) {
|
||||
// Set the global tracing flag.
|
||||
@@ -391,7 +393,9 @@ std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
||||
in.set_tracer(true);
|
||||
tracer_inputs.push_back(std::move(in));
|
||||
}
|
||||
return {tracer_inputs, fun(tracer_inputs)};
|
||||
|
||||
auto output = fun(tracer_inputs);
|
||||
return {tracer_inputs, output.first, output.second};
|
||||
}
|
||||
|
||||
// Traverses the graph to build a tape and a map of array ids to their parents
|
||||
@@ -932,8 +936,8 @@ bool skip_compile() {
|
||||
!(compile_available_for_device(default_device()));
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||
ArrayFnWithExtra compile(
|
||||
ArrayFnWithExtra fun,
|
||||
std::uintptr_t fun_id,
|
||||
bool shapeless /* = false */,
|
||||
std::vector<uint64_t> constants /* = {} */) {
|
||||
@@ -966,7 +970,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
// Set the constants
|
||||
entry.constants = std::move(constants);
|
||||
// Trace to build the graph
|
||||
std::tie(entry.inputs, entry.outputs) =
|
||||
std::tie(entry.inputs, entry.outputs, entry.extra) =
|
||||
compile_trace(fun, inputs, shapeless);
|
||||
|
||||
// DFS the graph and get a tape, and a map of array id to (parent,
|
||||
@@ -991,8 +995,30 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
|
||||
// At this point we must have a tape, now replace the placeholders
|
||||
// with real arrays that can be evaluated
|
||||
return compile_replace(
|
||||
entry.tape, entry.inputs, entry.outputs, inputs, shapeless);
|
||||
return std::pair<std::vector<array>, std::shared_ptr<void>>{
|
||||
compile_replace(
|
||||
entry.tape, entry.inputs, entry.outputs, inputs, shapeless),
|
||||
entry.extra};
|
||||
};
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||
std::uintptr_t fun_id,
|
||||
bool shapeless /* = false */,
|
||||
std::vector<uint64_t> constants /* = {} */) {
|
||||
ArrayFnWithExtra fun_with_extra =
|
||||
[fun = std::move(fun)](const std::vector<array>& inputs) {
|
||||
return std::pair<std::vector<array>, std::shared_ptr<void>>{
|
||||
fun(inputs), nullptr};
|
||||
};
|
||||
|
||||
auto compiled_fun = compile(
|
||||
std::move(fun_with_extra), fun_id, shapeless, std::move(constants));
|
||||
|
||||
return [compiled_fun =
|
||||
std::move(compiled_fun)](const std::vector<array>& inputs) {
|
||||
return compiled_fun(inputs).first;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,10 @@
|
||||
|
||||
namespace mlx::core::detail {
|
||||
|
||||
using ArraysAndExtra = std::pair<std::vector<array>, std::shared_ptr<void>>;
|
||||
using ArrayFnWithExtra =
|
||||
std::function<ArraysAndExtra(const std::vector<array>&)>;
|
||||
|
||||
// This is not part of the general C++ API as calling with a bad id is a bad
|
||||
// idea.
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
@@ -16,6 +20,12 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
bool shapeless = false,
|
||||
std::vector<uint64_t> constants = {});
|
||||
|
||||
ArrayFnWithExtra compile(
|
||||
ArrayFnWithExtra fun,
|
||||
std::uintptr_t fun_id,
|
||||
bool shapeless,
|
||||
std::vector<uint64_t> constants);
|
||||
|
||||
// Erase cached compile functions
|
||||
void compile_erase(std::uintptr_t fun_id);
|
||||
|
||||
@@ -25,8 +35,9 @@ void compile_clear_cache();
|
||||
|
||||
bool compile_available_for_device(const Device& device);
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
|
||||
compile_trace(
|
||||
const ArrayFnWithExtra& fun,
|
||||
const std::vector<array>& inputs,
|
||||
bool shapeless);
|
||||
|
||||
|
||||
@@ -579,11 +579,11 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
||||
for (auto& k : kwarg_keys) {
|
||||
kwargs.insert({k, *it++});
|
||||
}
|
||||
return fun(args, kwargs);
|
||||
return detail::ArraysAndExtra{fun(args, kwargs), nullptr};
|
||||
};
|
||||
|
||||
// Trace to build the graph
|
||||
auto [trace_inputs, trace_outputs] =
|
||||
auto [trace_inputs, trace_outputs, extra] =
|
||||
detail::compile_trace(flat_fun, inputs, ftable->shapeless);
|
||||
|
||||
// DFS the graph and get the tape
|
||||
|
||||
@@ -389,12 +389,6 @@ auto py_vmap(
|
||||
};
|
||||
}
|
||||
|
||||
std::unordered_map<std::uintptr_t, nb::object>& tree_cache() {
|
||||
// This map is used to Cache the tree structure of the outputs
|
||||
static std::unordered_map<std::uintptr_t, nb::object> tree_cache_;
|
||||
return tree_cache_;
|
||||
}
|
||||
|
||||
struct PyCompiledFun {
|
||||
nb::callable fun;
|
||||
std::uintptr_t fun_id;
|
||||
@@ -508,7 +502,10 @@ struct PyCompiledFun {
|
||||
auto [outputs, py_outputs] =
|
||||
tree_flatten_with_structure(std::move(tree_outputs), false);
|
||||
|
||||
tree_cache().insert({fun_id, py_outputs});
|
||||
std::shared_ptr<void> py_outputs_void(
|
||||
py_outputs.release().ptr(), [](void* handle) {
|
||||
nb::steal(reinterpret_cast<PyObject*>(handle)).reset();
|
||||
});
|
||||
|
||||
num_outputs = outputs.size();
|
||||
if (!captured_outputs.is_none()) {
|
||||
@@ -523,7 +520,7 @@ struct PyCompiledFun {
|
||||
if (!captured_inputs.is_none()) {
|
||||
tree_replace(captured_inputs, trace_captures, flat_in_captures);
|
||||
}
|
||||
return outputs;
|
||||
return mx::detail::ArraysAndExtra{outputs, py_outputs_void};
|
||||
};
|
||||
|
||||
if (!captured_inputs.is_none()) {
|
||||
@@ -535,7 +532,7 @@ struct PyCompiledFun {
|
||||
}
|
||||
|
||||
// Compile and call
|
||||
auto outputs =
|
||||
auto [outputs, py_outputs_void] =
|
||||
mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
||||
if (!captured_outputs.is_none()) {
|
||||
std::vector<mx::array> captures(
|
||||
@@ -545,7 +542,8 @@ struct PyCompiledFun {
|
||||
}
|
||||
|
||||
// Put the outputs back in the container
|
||||
nb::object py_outputs = tree_cache().at(fun_id);
|
||||
nb::object py_outputs =
|
||||
nb::borrow(reinterpret_cast<PyObject*>(py_outputs_void.get()));
|
||||
return tree_unflatten_from_structure(py_outputs, outputs);
|
||||
}
|
||||
|
||||
@@ -556,7 +554,6 @@ struct PyCompiledFun {
|
||||
~PyCompiledFun() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
tree_cache().erase(fun_id);
|
||||
mx::detail::compile_erase(fun_id);
|
||||
fun.reset();
|
||||
captured_inputs.reset();
|
||||
@@ -1479,8 +1476,6 @@ void init_transforms(nb::module_& m) {
|
||||
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
auto atexit = nb::module_::import_("atexit");
|
||||
atexit.attr("register")(nb::cpp_function([]() {
|
||||
tree_cache().clear();
|
||||
mx::detail::compile_clear_cache();
|
||||
}));
|
||||
atexit.attr("register")(
|
||||
nb::cpp_function([]() { mx::detail::compile_clear_cache(); }));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user