Compile now can attach arbitrary data to an entry (#2634)

This commit is contained in:
Angelos Katharopoulos
2025-09-30 13:33:27 -07:00
committed by GitHub
parent dc371ae7a5
commit eb24267b56
5 changed files with 130 additions and 31 deletions

View File

@@ -296,6 +296,7 @@ class CompilerCache {
std::vector<array> tape;
bool empty{true};
std::vector<uint64_t> constants;
std::shared_ptr<void> extra;
};
// Returns a reference to a CacheEntry which can be updated
@@ -376,8 +377,9 @@ CompilerCache& compiler_cache() {
return compiler_cache_;
}
std::pair<std::vector<array>, std::vector<array>> compile_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
compile_trace(
const ArrayFnWithExtra& fun,
const std::vector<array>& inputs,
bool shapeless) {
// Set the global tracing flag.
@@ -391,7 +393,9 @@ std::pair<std::vector<array>, std::vector<array>> compile_trace(
in.set_tracer(true);
tracer_inputs.push_back(std::move(in));
}
return {tracer_inputs, fun(tracer_inputs)};
auto output = fun(tracer_inputs);
return {tracer_inputs, output.first, output.second};
}
// Traverses the graph to build a tape and a map of array ids to their parents
@@ -932,8 +936,8 @@ bool skip_compile() {
!(compile_available_for_device(default_device()));
}
std::function<std::vector<array>(const std::vector<array>&)> compile(
std::function<std::vector<array>(const std::vector<array>&)> fun,
ArrayFnWithExtra compile(
ArrayFnWithExtra fun,
std::uintptr_t fun_id,
bool shapeless /* = false */,
std::vector<uint64_t> constants /* = {} */) {
@@ -966,7 +970,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
// Set the constants
entry.constants = std::move(constants);
// Trace to build the graph
std::tie(entry.inputs, entry.outputs) =
std::tie(entry.inputs, entry.outputs, entry.extra) =
compile_trace(fun, inputs, shapeless);
// DFS the graph and get a tape, and a map of array id to (parent,
@@ -991,8 +995,37 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
// At this point we must have a tape, now replace the placeholders
// with real arrays that can be evaluated
return compile_replace(
entry.tape, entry.inputs, entry.outputs, inputs, shapeless);
return ArraysAndExtra{
compile_replace(
entry.tape, entry.inputs, entry.outputs, inputs, shapeless),
entry.extra};
};
}
std::function<std::vector<array>(const std::vector<array>&)> compile(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::uintptr_t fun_id,
bool shapeless /* = false */,
std::vector<uint64_t> constants /* = {} */) {
if (skip_compile()) {
return fun;
}
if (!fun) {
throw std::invalid_argument(
"[compile] Cannot compile a function without a target.");
}
ArrayFnWithExtra fun_with_extra =
[fun = std::move(fun)](const std::vector<array>& inputs) {
return ArraysAndExtra{fun(inputs), nullptr};
};
auto compiled_fun = compile(
std::move(fun_with_extra), fun_id, shapeless, std::move(constants));
return [compiled_fun =
std::move(compiled_fun)](const std::vector<array>& inputs) {
return compiled_fun(inputs).first;
};
}