mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compile now can attach arbitrary data to an entry (#2634)
This commit is contained in:
committed by
GitHub
parent
dc371ae7a5
commit
eb24267b56
@@ -296,6 +296,7 @@ class CompilerCache {
|
||||
std::vector<array> tape;
|
||||
bool empty{true};
|
||||
std::vector<uint64_t> constants;
|
||||
std::shared_ptr<void> extra;
|
||||
};
|
||||
|
||||
// Returns a reference to a CacheEntry which can be updated
|
||||
@@ -376,8 +377,9 @@ CompilerCache& compiler_cache() {
|
||||
return compiler_cache_;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
|
||||
compile_trace(
|
||||
const ArrayFnWithExtra& fun,
|
||||
const std::vector<array>& inputs,
|
||||
bool shapeless) {
|
||||
// Set the global tracing flag.
|
||||
@@ -391,7 +393,9 @@ std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
||||
in.set_tracer(true);
|
||||
tracer_inputs.push_back(std::move(in));
|
||||
}
|
||||
return {tracer_inputs, fun(tracer_inputs)};
|
||||
|
||||
auto output = fun(tracer_inputs);
|
||||
return {tracer_inputs, output.first, output.second};
|
||||
}
|
||||
|
||||
// Traverses the graph to build a tape and a map of array ids to their parents
|
||||
@@ -932,8 +936,8 @@ bool skip_compile() {
|
||||
!(compile_available_for_device(default_device()));
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||
ArrayFnWithExtra compile(
|
||||
ArrayFnWithExtra fun,
|
||||
std::uintptr_t fun_id,
|
||||
bool shapeless /* = false */,
|
||||
std::vector<uint64_t> constants /* = {} */) {
|
||||
@@ -966,7 +970,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
// Set the constants
|
||||
entry.constants = std::move(constants);
|
||||
// Trace to build the graph
|
||||
std::tie(entry.inputs, entry.outputs) =
|
||||
std::tie(entry.inputs, entry.outputs, entry.extra) =
|
||||
compile_trace(fun, inputs, shapeless);
|
||||
|
||||
// DFS the graph and get a tape, and a map of array id to (parent,
|
||||
@@ -991,8 +995,37 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
|
||||
// At this point we must have a tape, now replace the placeholders
|
||||
// with real arrays that can be evaluated
|
||||
return compile_replace(
|
||||
entry.tape, entry.inputs, entry.outputs, inputs, shapeless);
|
||||
return ArraysAndExtra{
|
||||
compile_replace(
|
||||
entry.tape, entry.inputs, entry.outputs, inputs, shapeless),
|
||||
entry.extra};
|
||||
};
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||
std::uintptr_t fun_id,
|
||||
bool shapeless /* = false */,
|
||||
std::vector<uint64_t> constants /* = {} */) {
|
||||
if (skip_compile()) {
|
||||
return fun;
|
||||
}
|
||||
if (!fun) {
|
||||
throw std::invalid_argument(
|
||||
"[compile] Cannot compile a function without a target.");
|
||||
}
|
||||
|
||||
ArrayFnWithExtra fun_with_extra =
|
||||
[fun = std::move(fun)](const std::vector<array>& inputs) {
|
||||
return ArraysAndExtra{fun(inputs), nullptr};
|
||||
};
|
||||
|
||||
auto compiled_fun = compile(
|
||||
std::move(fun_with_extra), fun_id, shapeless, std::move(constants));
|
||||
|
||||
return [compiled_fun =
|
||||
std::move(compiled_fun)](const std::vector<array>& inputs) {
|
||||
return compiled_fun(inputs).first;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,10 @@
|
||||
|
||||
namespace mlx::core::detail {
|
||||
|
||||
using ArraysAndExtra = std::pair<std::vector<array>, std::shared_ptr<void>>;
|
||||
using ArrayFnWithExtra =
|
||||
std::function<ArraysAndExtra(const std::vector<array>&)>;
|
||||
|
||||
// This is not part of the general C++ API as calling with a bad id is a bad
|
||||
// idea.
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
@@ -16,6 +20,12 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
bool shapeless = false,
|
||||
std::vector<uint64_t> constants = {});
|
||||
|
||||
ArrayFnWithExtra compile(
|
||||
ArrayFnWithExtra fun,
|
||||
std::uintptr_t fun_id,
|
||||
bool shapeless,
|
||||
std::vector<uint64_t> constants);
|
||||
|
||||
// Erase cached compile functions
|
||||
void compile_erase(std::uintptr_t fun_id);
|
||||
|
||||
@@ -25,8 +35,9 @@ void compile_clear_cache();
|
||||
|
||||
bool compile_available_for_device(const Device& device);
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
|
||||
compile_trace(
|
||||
const ArrayFnWithExtra& fun,
|
||||
const std::vector<array>& inputs,
|
||||
bool shapeless);
|
||||
|
||||
|
||||
@@ -579,11 +579,11 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
||||
for (auto& k : kwarg_keys) {
|
||||
kwargs.insert({k, *it++});
|
||||
}
|
||||
return fun(args, kwargs);
|
||||
return detail::ArraysAndExtra{fun(args, kwargs), nullptr};
|
||||
};
|
||||
|
||||
// Trace to build the graph
|
||||
auto [trace_inputs, trace_outputs] =
|
||||
auto [trace_inputs, trace_outputs, extra] =
|
||||
detail::compile_trace(flat_fun, inputs, ftable->shapeless);
|
||||
|
||||
// DFS the graph and get the tape
|
||||
|
||||
Reference in New Issue
Block a user