mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Compile now can attach arbitrary data to an entry (#2634)
This commit is contained in:

committed by
GitHub

parent
dc371ae7a5
commit
eb24267b56
@@ -296,6 +296,7 @@ class CompilerCache {
|
|||||||
std::vector<array> tape;
|
std::vector<array> tape;
|
||||||
bool empty{true};
|
bool empty{true};
|
||||||
std::vector<uint64_t> constants;
|
std::vector<uint64_t> constants;
|
||||||
|
std::shared_ptr<void> extra;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Returns a reference to a CacheEntry which can be updated
|
// Returns a reference to a CacheEntry which can be updated
|
||||||
@@ -376,8 +377,9 @@ CompilerCache& compiler_cache() {
|
|||||||
return compiler_cache_;
|
return compiler_cache_;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
compile_trace(
|
||||||
|
const ArrayFnWithExtra& fun,
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
bool shapeless) {
|
bool shapeless) {
|
||||||
// Set the global tracing flag.
|
// Set the global tracing flag.
|
||||||
@@ -391,7 +393,9 @@ std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
|||||||
in.set_tracer(true);
|
in.set_tracer(true);
|
||||||
tracer_inputs.push_back(std::move(in));
|
tracer_inputs.push_back(std::move(in));
|
||||||
}
|
}
|
||||||
return {tracer_inputs, fun(tracer_inputs)};
|
|
||||||
|
auto output = fun(tracer_inputs);
|
||||||
|
return {tracer_inputs, output.first, output.second};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Traverses the graph to build a tape and a map of array ids to their parents
|
// Traverses the graph to build a tape and a map of array ids to their parents
|
||||||
@@ -932,8 +936,8 @@ bool skip_compile() {
|
|||||||
!(compile_available_for_device(default_device()));
|
!(compile_available_for_device(default_device()));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
ArrayFnWithExtra compile(
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
ArrayFnWithExtra fun,
|
||||||
std::uintptr_t fun_id,
|
std::uintptr_t fun_id,
|
||||||
bool shapeless /* = false */,
|
bool shapeless /* = false */,
|
||||||
std::vector<uint64_t> constants /* = {} */) {
|
std::vector<uint64_t> constants /* = {} */) {
|
||||||
@@ -966,7 +970,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
// Set the constants
|
// Set the constants
|
||||||
entry.constants = std::move(constants);
|
entry.constants = std::move(constants);
|
||||||
// Trace to build the graph
|
// Trace to build the graph
|
||||||
std::tie(entry.inputs, entry.outputs) =
|
std::tie(entry.inputs, entry.outputs, entry.extra) =
|
||||||
compile_trace(fun, inputs, shapeless);
|
compile_trace(fun, inputs, shapeless);
|
||||||
|
|
||||||
// DFS the graph and get a tape, and a map of array id to (parent,
|
// DFS the graph and get a tape, and a map of array id to (parent,
|
||||||
@@ -991,8 +995,37 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
|
|
||||||
// At this point we must have a tape, now replace the placeholders
|
// At this point we must have a tape, now replace the placeholders
|
||||||
// with real arrays that can be evaluated
|
// with real arrays that can be evaluated
|
||||||
return compile_replace(
|
return ArraysAndExtra{
|
||||||
entry.tape, entry.inputs, entry.outputs, inputs, shapeless);
|
compile_replace(
|
||||||
|
entry.tape, entry.inputs, entry.outputs, inputs, shapeless),
|
||||||
|
entry.extra};
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||||
|
std::uintptr_t fun_id,
|
||||||
|
bool shapeless /* = false */,
|
||||||
|
std::vector<uint64_t> constants /* = {} */) {
|
||||||
|
if (skip_compile()) {
|
||||||
|
return fun;
|
||||||
|
}
|
||||||
|
if (!fun) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[compile] Cannot compile a function without a target.");
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayFnWithExtra fun_with_extra =
|
||||||
|
[fun = std::move(fun)](const std::vector<array>& inputs) {
|
||||||
|
return ArraysAndExtra{fun(inputs), nullptr};
|
||||||
|
};
|
||||||
|
|
||||||
|
auto compiled_fun = compile(
|
||||||
|
std::move(fun_with_extra), fun_id, shapeless, std::move(constants));
|
||||||
|
|
||||||
|
return [compiled_fun =
|
||||||
|
std::move(compiled_fun)](const std::vector<array>& inputs) {
|
||||||
|
return compiled_fun(inputs).first;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -8,6 +8,10 @@
|
|||||||
|
|
||||||
namespace mlx::core::detail {
|
namespace mlx::core::detail {
|
||||||
|
|
||||||
|
using ArraysAndExtra = std::pair<std::vector<array>, std::shared_ptr<void>>;
|
||||||
|
using ArrayFnWithExtra =
|
||||||
|
std::function<ArraysAndExtra(const std::vector<array>&)>;
|
||||||
|
|
||||||
// This is not part of the general C++ API as calling with a bad id is a bad
|
// This is not part of the general C++ API as calling with a bad id is a bad
|
||||||
// idea.
|
// idea.
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
@@ -16,6 +20,12 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
bool shapeless = false,
|
bool shapeless = false,
|
||||||
std::vector<uint64_t> constants = {});
|
std::vector<uint64_t> constants = {});
|
||||||
|
|
||||||
|
ArrayFnWithExtra compile(
|
||||||
|
ArrayFnWithExtra fun,
|
||||||
|
std::uintptr_t fun_id,
|
||||||
|
bool shapeless,
|
||||||
|
std::vector<uint64_t> constants);
|
||||||
|
|
||||||
// Erase cached compile functions
|
// Erase cached compile functions
|
||||||
void compile_erase(std::uintptr_t fun_id);
|
void compile_erase(std::uintptr_t fun_id);
|
||||||
|
|
||||||
@@ -25,8 +35,9 @@ void compile_clear_cache();
|
|||||||
|
|
||||||
bool compile_available_for_device(const Device& device);
|
bool compile_available_for_device(const Device& device);
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
compile_trace(
|
||||||
|
const ArrayFnWithExtra& fun,
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
bool shapeless);
|
bool shapeless);
|
||||||
|
|
||||||
|
@@ -579,11 +579,11 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
|||||||
for (auto& k : kwarg_keys) {
|
for (auto& k : kwarg_keys) {
|
||||||
kwargs.insert({k, *it++});
|
kwargs.insert({k, *it++});
|
||||||
}
|
}
|
||||||
return fun(args, kwargs);
|
return detail::ArraysAndExtra{fun(args, kwargs), nullptr};
|
||||||
};
|
};
|
||||||
|
|
||||||
// Trace to build the graph
|
// Trace to build the graph
|
||||||
auto [trace_inputs, trace_outputs] =
|
auto [trace_inputs, trace_outputs, extra] =
|
||||||
detail::compile_trace(flat_fun, inputs, ftable->shapeless);
|
detail::compile_trace(flat_fun, inputs, ftable->shapeless);
|
||||||
|
|
||||||
// DFS the graph and get the tape
|
// DFS the graph and get the tape
|
||||||
|
@@ -389,19 +389,22 @@ auto py_vmap(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<std::uintptr_t, nb::object>& tree_cache() {
|
|
||||||
// This map is used to Cache the tree structure of the outputs
|
|
||||||
static std::unordered_map<std::uintptr_t, nb::object> tree_cache_;
|
|
||||||
return tree_cache_;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct PyCompiledFun {
|
struct PyCompiledFun {
|
||||||
nb::callable fun;
|
nb::callable fun;
|
||||||
std::uintptr_t fun_id;
|
std::uintptr_t fun_id;
|
||||||
nb::object captured_inputs;
|
nb::object captured_inputs;
|
||||||
nb::object captured_outputs;
|
nb::object captured_outputs;
|
||||||
bool shapeless;
|
bool shapeless;
|
||||||
mutable size_t num_outputs{0};
|
|
||||||
|
// Data to attach to the compiled function that contains the python output
|
||||||
|
// structure and the number of arrays in said structure.
|
||||||
|
struct AttachedData {
|
||||||
|
nb::object output_structure;
|
||||||
|
int num_outputs;
|
||||||
|
|
||||||
|
AttachedData(nb::object output_structure_, int num_outputs_)
|
||||||
|
: output_structure(output_structure_), num_outputs(num_outputs_) {}
|
||||||
|
};
|
||||||
|
|
||||||
PyCompiledFun(
|
PyCompiledFun(
|
||||||
const nb::callable& fun,
|
const nb::callable& fun,
|
||||||
@@ -424,7 +427,6 @@ struct PyCompiledFun {
|
|||||||
captured_inputs = std::move(other.captured_inputs);
|
captured_inputs = std::move(other.captured_inputs);
|
||||||
captured_outputs = std::move(other.captured_outputs);
|
captured_outputs = std::move(other.captured_outputs);
|
||||||
shapeless = other.shapeless;
|
shapeless = other.shapeless;
|
||||||
num_outputs = other.num_outputs;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
|
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
|
||||||
@@ -508,9 +510,9 @@ struct PyCompiledFun {
|
|||||||
auto [outputs, py_outputs] =
|
auto [outputs, py_outputs] =
|
||||||
tree_flatten_with_structure(std::move(tree_outputs), false);
|
tree_flatten_with_structure(std::move(tree_outputs), false);
|
||||||
|
|
||||||
tree_cache().insert({fun_id, py_outputs});
|
std::shared_ptr<void> extra_data =
|
||||||
|
std::make_shared<AttachedData>(py_outputs, outputs.size());
|
||||||
|
|
||||||
num_outputs = outputs.size();
|
|
||||||
if (!captured_outputs.is_none()) {
|
if (!captured_outputs.is_none()) {
|
||||||
auto flat_out_captures = tree_flatten(captured_outputs, false);
|
auto flat_out_captures = tree_flatten(captured_outputs, false);
|
||||||
outputs.insert(
|
outputs.insert(
|
||||||
@@ -523,7 +525,7 @@ struct PyCompiledFun {
|
|||||||
if (!captured_inputs.is_none()) {
|
if (!captured_inputs.is_none()) {
|
||||||
tree_replace(captured_inputs, trace_captures, flat_in_captures);
|
tree_replace(captured_inputs, trace_captures, flat_in_captures);
|
||||||
}
|
}
|
||||||
return outputs;
|
return mx::detail::ArraysAndExtra{outputs, extra_data};
|
||||||
};
|
};
|
||||||
|
|
||||||
if (!captured_inputs.is_none()) {
|
if (!captured_inputs.is_none()) {
|
||||||
@@ -535,8 +537,14 @@ struct PyCompiledFun {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Compile and call
|
// Compile and call
|
||||||
auto outputs =
|
auto [outputs, extra_data] =
|
||||||
mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
||||||
|
|
||||||
|
int num_outputs =
|
||||||
|
reinterpret_cast<AttachedData*>(extra_data.get())->num_outputs;
|
||||||
|
nb::object py_outputs =
|
||||||
|
reinterpret_cast<AttachedData*>(extra_data.get())->output_structure;
|
||||||
|
|
||||||
if (!captured_outputs.is_none()) {
|
if (!captured_outputs.is_none()) {
|
||||||
std::vector<mx::array> captures(
|
std::vector<mx::array> captures(
|
||||||
std::make_move_iterator(outputs.begin() + num_outputs),
|
std::make_move_iterator(outputs.begin() + num_outputs),
|
||||||
@@ -545,8 +553,7 @@ struct PyCompiledFun {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Put the outputs back in the container
|
// Put the outputs back in the container
|
||||||
nb::object py_outputs = tree_cache().at(fun_id);
|
return tree_unflatten_from_structure(std::move(py_outputs), outputs);
|
||||||
return tree_unflatten_from_structure(py_outputs, outputs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
|
nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
|
||||||
@@ -556,7 +563,6 @@ struct PyCompiledFun {
|
|||||||
~PyCompiledFun() {
|
~PyCompiledFun() {
|
||||||
nb::gil_scoped_acquire gil;
|
nb::gil_scoped_acquire gil;
|
||||||
|
|
||||||
tree_cache().erase(fun_id);
|
|
||||||
mx::detail::compile_erase(fun_id);
|
mx::detail::compile_erase(fun_id);
|
||||||
fun.reset();
|
fun.reset();
|
||||||
captured_inputs.reset();
|
captured_inputs.reset();
|
||||||
@@ -1479,8 +1485,6 @@ void init_transforms(nb::module_& m) {
|
|||||||
|
|
||||||
// Register static Python object cleanup before the interpreter exits
|
// Register static Python object cleanup before the interpreter exits
|
||||||
auto atexit = nb::module_::import_("atexit");
|
auto atexit = nb::module_::import_("atexit");
|
||||||
atexit.attr("register")(nb::cpp_function([]() {
|
atexit.attr("register")(
|
||||||
tree_cache().clear();
|
nb::cpp_function([]() { mx::detail::compile_clear_cache(); }));
|
||||||
mx::detail::compile_clear_cache();
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
|
@@ -1064,6 +1064,57 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
out = fun(mx.array(1.0), mx.array(2.0))
|
out = fun(mx.array(1.0), mx.array(2.0))
|
||||||
self.assertEqual(out.item(), 3.0)
|
self.assertEqual(out.item(), 3.0)
|
||||||
|
|
||||||
|
def test_compile_changing_outputs(self):
|
||||||
|
@mx.compile
|
||||||
|
def fun(x, y):
|
||||||
|
if y is None:
|
||||||
|
return 2 * x
|
||||||
|
elif (
|
||||||
|
isinstance(x, mx.array)
|
||||||
|
and isinstance(y, mx.array)
|
||||||
|
and x.dtype == y.dtype == mx.float32
|
||||||
|
):
|
||||||
|
return [x + y]
|
||||||
|
elif y.dtype == mx.bool_:
|
||||||
|
return {"a": x, "b": y * x}
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
a = fun(mx.array(1.0), mx.array(2.0))
|
||||||
|
self.assertTrue(isinstance(a, list))
|
||||||
|
self.assertEqual(a[0].item(), 3.0)
|
||||||
|
|
||||||
|
b = fun(mx.array(1.0), mx.array(True))
|
||||||
|
self.assertTrue(isinstance(b, dict))
|
||||||
|
self.assertEqual(b["a"].item(), 1.0)
|
||||||
|
self.assertEqual(b["b"].item(), 1.0)
|
||||||
|
|
||||||
|
c = fun(mx.array(1.0), None)
|
||||||
|
self.assertTrue(isinstance(c, mx.array))
|
||||||
|
self.assertEqual(c.item(), 2.0)
|
||||||
|
|
||||||
|
d = fun(False, mx.array(1.0))
|
||||||
|
self.assertTrue(d is None)
|
||||||
|
|
||||||
|
def test_compile_changing_outputs_with_state(self):
|
||||||
|
state = [mx.array(1.0)]
|
||||||
|
|
||||||
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
|
def fun(y):
|
||||||
|
x = state[0]
|
||||||
|
if y.dtype == mx.float32:
|
||||||
|
state[0] = 2 * y
|
||||||
|
return [x, y, x + y]
|
||||||
|
elif y.dtype == mx.int32:
|
||||||
|
state[0] *= 2
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
fun(mx.array(1.0))
|
||||||
|
fun(mx.array(1))
|
||||||
|
|
||||||
|
self.assertEqual(state[0].item(), 4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
Reference in New Issue
Block a user