Compile now can attach arbitrary data to an entry (#2634)

This commit is contained in:
Angelos Katharopoulos
2025-09-30 13:33:27 -07:00
committed by GitHub
parent dc371ae7a5
commit eb24267b56
5 changed files with 130 additions and 31 deletions

View File

@@ -389,19 +389,22 @@ auto py_vmap(
};
}
std::unordered_map<std::uintptr_t, nb::object>& tree_cache() {
// This map is used to Cache the tree structure of the outputs
static std::unordered_map<std::uintptr_t, nb::object> tree_cache_;
return tree_cache_;
}
struct PyCompiledFun {
nb::callable fun;
std::uintptr_t fun_id;
nb::object captured_inputs;
nb::object captured_outputs;
bool shapeless;
mutable size_t num_outputs{0};
// Data to attach to the compiled function that contains the python output
// structure and the number of arrays in said structure.
struct AttachedData {
nb::object output_structure;
int num_outputs;
AttachedData(nb::object output_structure_, int num_outputs_)
: output_structure(output_structure_), num_outputs(num_outputs_) {}
};
PyCompiledFun(
const nb::callable& fun,
@@ -424,7 +427,6 @@ struct PyCompiledFun {
captured_inputs = std::move(other.captured_inputs);
captured_outputs = std::move(other.captured_outputs);
shapeless = other.shapeless;
num_outputs = other.num_outputs;
};
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
@@ -508,9 +510,9 @@ struct PyCompiledFun {
auto [outputs, py_outputs] =
tree_flatten_with_structure(std::move(tree_outputs), false);
tree_cache().insert({fun_id, py_outputs});
std::shared_ptr<void> extra_data =
std::make_shared<AttachedData>(py_outputs, outputs.size());
num_outputs = outputs.size();
if (!captured_outputs.is_none()) {
auto flat_out_captures = tree_flatten(captured_outputs, false);
outputs.insert(
@@ -523,7 +525,7 @@ struct PyCompiledFun {
if (!captured_inputs.is_none()) {
tree_replace(captured_inputs, trace_captures, flat_in_captures);
}
return outputs;
return mx::detail::ArraysAndExtra{outputs, extra_data};
};
if (!captured_inputs.is_none()) {
@@ -535,8 +537,14 @@ struct PyCompiledFun {
}
// Compile and call
auto outputs =
auto [outputs, extra_data] =
mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
int num_outputs =
reinterpret_cast<AttachedData*>(extra_data.get())->num_outputs;
nb::object py_outputs =
reinterpret_cast<AttachedData*>(extra_data.get())->output_structure;
if (!captured_outputs.is_none()) {
std::vector<mx::array> captures(
std::make_move_iterator(outputs.begin() + num_outputs),
@@ -545,8 +553,7 @@ struct PyCompiledFun {
}
// Put the outputs back in the container
nb::object py_outputs = tree_cache().at(fun_id);
return tree_unflatten_from_structure(py_outputs, outputs);
return tree_unflatten_from_structure(std::move(py_outputs), outputs);
}
nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
@@ -556,7 +563,6 @@ struct PyCompiledFun {
~PyCompiledFun() {
nb::gil_scoped_acquire gil;
tree_cache().erase(fun_id);
mx::detail::compile_erase(fun_id);
fun.reset();
captured_inputs.reset();
@@ -1479,8 +1485,6 @@ void init_transforms(nb::module_& m) {
// Register static Python object cleanup before the interpreter exits
auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(nb::cpp_function([]() {
tree_cache().clear();
mx::detail::compile_clear_cache();
}));
atexit.attr("register")(
nb::cpp_function([]() { mx::detail::compile_clear_cache(); }));
}

View File

@@ -1064,6 +1064,57 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(mx.array(1.0), mx.array(2.0))
self.assertEqual(out.item(), 3.0)
def test_compile_changing_outputs(self):
@mx.compile
def fun(x, y):
if y is None:
return 2 * x
elif (
isinstance(x, mx.array)
and isinstance(y, mx.array)
and x.dtype == y.dtype == mx.float32
):
return [x + y]
elif y.dtype == mx.bool_:
return {"a": x, "b": y * x}
else:
return None
a = fun(mx.array(1.0), mx.array(2.0))
self.assertTrue(isinstance(a, list))
self.assertEqual(a[0].item(), 3.0)
b = fun(mx.array(1.0), mx.array(True))
self.assertTrue(isinstance(b, dict))
self.assertEqual(b["a"].item(), 1.0)
self.assertEqual(b["b"].item(), 1.0)
c = fun(mx.array(1.0), None)
self.assertTrue(isinstance(c, mx.array))
self.assertEqual(c.item(), 2.0)
d = fun(False, mx.array(1.0))
self.assertTrue(d is None)
def test_compile_changing_outputs_with_state(self):
state = [mx.array(1.0)]
@partial(mx.compile, inputs=state, outputs=state)
def fun(y):
x = state[0]
if y.dtype == mx.float32:
state[0] = 2 * y
return [x, y, x + y]
elif y.dtype == mx.int32:
state[0] *= 2
return x + y
for i in range(10):
fun(mx.array(1.0))
fun(mx.array(1))
self.assertEqual(state[0].item(), 4)
if __name__ == "__main__":
mlx_tests.MLXTestRunner()