mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-29 22:01:17 +08:00
fix python globals bug, and erase
This commit is contained in:
parent
6189111494
commit
b75ff47098
@ -70,6 +70,10 @@ struct CompilerCache {
|
|||||||
return entries.back();
|
return entries.back();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void erase(size_t fun_id) {
|
||||||
|
cache_.erase(fun_id);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
CompilerCache() {}
|
CompilerCache() {}
|
||||||
friend CompilerCache& compiler_cache();
|
friend CompilerCache& compiler_cache();
|
||||||
@ -357,6 +361,11 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs);
|
return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void compile_erase(size_t fun_id) {
|
||||||
|
detail::compiler_cache().erase(fun_id);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
|
@ -20,6 +20,9 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
size_t fun_id);
|
size_t fun_id);
|
||||||
|
|
||||||
|
// Erase cached compile functions
|
||||||
|
void compile_erase(size_t fun_id);
|
||||||
|
|
||||||
// Create an InTracing object during tracing operations to signify to the rest
|
// Create an InTracing object during tracing operations to signify to the rest
|
||||||
// of the codebase that we are during tracing so evals should not throw away
|
// of the codebase that we are during tracing so evals should not throw away
|
||||||
// the graph.
|
// the graph.
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <iostream> // TODO
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
@ -455,14 +456,16 @@ std::unordered_map<size_t, py::object>& tree_cache() {
|
|||||||
return tree_cache_;
|
return tree_cache_;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto py_compile(const py::function& fun) {
|
struct PyCompiledFun {
|
||||||
return [fun](const py::args& args) {
|
py::function fun;
|
||||||
|
|
||||||
|
py::object operator()(const py::args& args) {
|
||||||
// TODO, awni, I think this cast is ok??
|
// TODO, awni, I think this cast is ok??
|
||||||
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
|
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
|
||||||
|
|
||||||
auto compile_fun = [fun_id, &fun, &args](const std::vector<array>& a) {
|
auto compile_fun = [fun_id, this, &args](const std::vector<array>& a) {
|
||||||
// Call the python function
|
// Call the python function
|
||||||
py::object py_outputs = fun(*tree_unflatten(args, a));
|
py::object py_outputs = this->fun(*tree_unflatten(args, a));
|
||||||
|
|
||||||
// Flatten the outputs
|
// Flatten the outputs
|
||||||
auto outputs = tree_flatten(py_outputs, true);
|
auto outputs = tree_flatten(py_outputs, true);
|
||||||
@ -477,12 +480,20 @@ auto py_compile(const py::function& fun) {
|
|||||||
auto inputs = tree_flatten(args, true);
|
auto inputs = tree_flatten(args, true);
|
||||||
|
|
||||||
// Get globally enclosed arrays so we don't compile through them
|
// Get globally enclosed arrays so we don't compile through them
|
||||||
|
// c.f. https://github.com/python/cpython/blob/main/Lib/inspect.py#L1638
|
||||||
if (py::hasattr(fun, "__globals__")) {
|
if (py::hasattr(fun, "__globals__")) {
|
||||||
auto global_inputs = tree_flatten(py::getattr(fun, "__globals__"), false);
|
py::dict globals = py::getattr(fun, "__globals__");
|
||||||
std::move(
|
auto co_names = py::getattr(py::getattr(fun, "__code__"), "co_names");
|
||||||
std::begin(global_inputs),
|
for (auto& n : co_names) {
|
||||||
std::end(global_inputs),
|
if (py::cast<bool>(globals.attr("__contains__")(n))) {
|
||||||
std::back_inserter(inputs));
|
auto global_inputs =
|
||||||
|
tree_flatten(globals.attr("__getitem__")(n), false);
|
||||||
|
std::move(
|
||||||
|
std::begin(global_inputs),
|
||||||
|
std::end(global_inputs),
|
||||||
|
std::back_inserter(inputs));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get locally enclosed arrays so we don't compile through them
|
// Get locally enclosed arrays so we don't compile through them
|
||||||
@ -507,7 +518,12 @@ auto py_compile(const py::function& fun) {
|
|||||||
py::object py_outputs = tree_cache().at(fun_id);
|
py::object py_outputs = tree_cache().at(fun_id);
|
||||||
return tree_unflatten_none(py_outputs, outputs);
|
return tree_unflatten_none(py_outputs, outputs);
|
||||||
};
|
};
|
||||||
}
|
|
||||||
|
~PyCompiledFun() {
|
||||||
|
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
|
||||||
|
detail::compile_erase(fun_id);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void init_transforms(py::module_& m) {
|
void init_transforms(py::module_& m) {
|
||||||
py::options options;
|
py::options options;
|
||||||
@ -810,7 +826,9 @@ void init_transforms(py::module_& m) {
|
|||||||
"file"_a);
|
"file"_a);
|
||||||
m.def(
|
m.def(
|
||||||
"compile",
|
"compile",
|
||||||
[](const py::function& fun) { return py::cpp_function(py_compile(fun)); },
|
[](const py::function& fun) {
|
||||||
|
return py::cpp_function(PyCompiledFun{fun});
|
||||||
|
},
|
||||||
"fun"_a,
|
"fun"_a,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
compile(fun: function) -> function
|
compile(fun: function) -> function
|
||||||
|
Loading…
Reference in New Issue
Block a user