Compile now can attach arbitrary data to an entry (#2634)

This commit is contained in:
Angelos Katharopoulos
2025-09-30 13:33:27 -07:00
committed by GitHub
parent dc371ae7a5
commit eb24267b56
5 changed files with 130 additions and 31 deletions

View File

@@ -296,6 +296,7 @@ class CompilerCache {
std::vector<array> tape;
bool empty{true};
std::vector<uint64_t> constants;
std::shared_ptr<void> extra;
};
// Returns a reference to a CacheEntry which can be updated
@@ -376,8 +377,9 @@ CompilerCache& compiler_cache() {
return compiler_cache_;
}
std::pair<std::vector<array>, std::vector<array>> compile_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
compile_trace(
const ArrayFnWithExtra& fun,
const std::vector<array>& inputs,
bool shapeless) {
// Set the global tracing flag.
@@ -391,7 +393,9 @@ std::pair<std::vector<array>, std::vector<array>> compile_trace(
in.set_tracer(true);
tracer_inputs.push_back(std::move(in));
}
return {tracer_inputs, fun(tracer_inputs)};
auto output = fun(tracer_inputs);
return {tracer_inputs, output.first, output.second};
}
// Traverses the graph to build a tape and a map of array ids to their parents
@@ -932,8 +936,8 @@ bool skip_compile() {
!(compile_available_for_device(default_device()));
}
std::function<std::vector<array>(const std::vector<array>&)> compile(
std::function<std::vector<array>(const std::vector<array>&)> fun,
ArrayFnWithExtra compile(
ArrayFnWithExtra fun,
std::uintptr_t fun_id,
bool shapeless /* = false */,
std::vector<uint64_t> constants /* = {} */) {
@@ -966,7 +970,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
// Set the constants
entry.constants = std::move(constants);
// Trace to build the graph
std::tie(entry.inputs, entry.outputs) =
std::tie(entry.inputs, entry.outputs, entry.extra) =
compile_trace(fun, inputs, shapeless);
// DFS the graph and get a tape, and a map of array id to (parent,
@@ -991,8 +995,37 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
// At this point we must have a tape, now replace the placeholders
// with real arrays that can be evaluated
return compile_replace(
entry.tape, entry.inputs, entry.outputs, inputs, shapeless);
return ArraysAndExtra{
compile_replace(
entry.tape, entry.inputs, entry.outputs, inputs, shapeless),
entry.extra};
};
}
std::function<std::vector<array>(const std::vector<array>&)> compile(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::uintptr_t fun_id,
bool shapeless /* = false */,
std::vector<uint64_t> constants /* = {} */) {
if (skip_compile()) {
return fun;
}
if (!fun) {
throw std::invalid_argument(
"[compile] Cannot compile a function without a target.");
}
ArrayFnWithExtra fun_with_extra =
[fun = std::move(fun)](const std::vector<array>& inputs) {
return ArraysAndExtra{fun(inputs), nullptr};
};
auto compiled_fun = compile(
std::move(fun_with_extra), fun_id, shapeless, std::move(constants));
return [compiled_fun =
std::move(compiled_fun)](const std::vector<array>& inputs) {
return compiled_fun(inputs).first;
};
}

View File

@@ -8,6 +8,10 @@
namespace mlx::core::detail {
using ArraysAndExtra = std::pair<std::vector<array>, std::shared_ptr<void>>;
using ArrayFnWithExtra =
std::function<ArraysAndExtra(const std::vector<array>&)>;
// This is not part of the general C++ API as calling with a bad id is a bad
// idea.
std::function<std::vector<array>(const std::vector<array>&)> compile(
@@ -16,6 +20,12 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
bool shapeless = false,
std::vector<uint64_t> constants = {});
ArrayFnWithExtra compile(
ArrayFnWithExtra fun,
std::uintptr_t fun_id,
bool shapeless,
std::vector<uint64_t> constants);
// Erase cached compile functions
void compile_erase(std::uintptr_t fun_id);
@@ -25,8 +35,9 @@ void compile_clear_cache();
bool compile_available_for_device(const Device& device);
std::pair<std::vector<array>, std::vector<array>> compile_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
compile_trace(
const ArrayFnWithExtra& fun,
const std::vector<array>& inputs,
bool shapeless);

View File

@@ -579,11 +579,11 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
for (auto& k : kwarg_keys) {
kwargs.insert({k, *it++});
}
return fun(args, kwargs);
return detail::ArraysAndExtra{fun(args, kwargs), nullptr};
};
// Trace to build the graph
auto [trace_inputs, trace_outputs] =
auto [trace_inputs, trace_outputs, extra] =
detail::compile_trace(flat_fun, inputs, ftable->shapeless);
// DFS the graph and get the tape