mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-08 05:18:12 +08:00
compile works with python closures
This commit is contained in:
@@ -457,14 +457,10 @@ std::unordered_map<size_t, py::object>& tree_cache() {
|
||||
|
||||
auto py_compile(const py::function& fun) {
|
||||
return [fun](const py::args& args) {
|
||||
// Inputs must be array or tree of arrays
|
||||
auto inputs = tree_flatten(args, true);
|
||||
|
||||
// TODO, awni, I think this cast is ok??
|
||||
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
|
||||
|
||||
auto compile_fun = [fun_id, &fun, &args, &inputs](
|
||||
const std::vector<array>& a) {
|
||||
auto compile_fun = [fun_id, &fun, &args](const std::vector<array>& a) {
|
||||
// Call the python function
|
||||
py::object py_outputs = fun(*tree_unflatten(args, a));
|
||||
|
||||
@@ -477,6 +473,29 @@ auto py_compile(const py::function& fun) {
|
||||
return outputs;
|
||||
};
|
||||
|
||||
// Inputs must be array or tree of arrays
|
||||
auto inputs = tree_flatten(args, true);
|
||||
|
||||
// Get globally enclosed arrays so we don't compile through them
|
||||
auto global_inputs = tree_flatten(py::getattr(fun, "__globals__"), false);
|
||||
std::move(
|
||||
std::begin(global_inputs),
|
||||
std::end(global_inputs),
|
||||
std::back_inserter(inputs));
|
||||
|
||||
// Get locally enclosed arrays so we don't compile through them
|
||||
auto closures = py::getattr(fun, "__closure__");
|
||||
if (py::isinstance<py::tuple>(closures)) {
|
||||
for (auto& closure : closures) {
|
||||
auto enclosed_inputs =
|
||||
tree_flatten(py::getattr(closure, "cell_contents"), false);
|
||||
std::move(
|
||||
std::begin(enclosed_inputs),
|
||||
std::end(enclosed_inputs),
|
||||
std::back_inserter(inputs));
|
||||
}
|
||||
}
|
||||
|
||||
// Compile and call
|
||||
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user