mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-28 21:21:21 +08:00
fix python globals bug, and erase
This commit is contained in:
parent
6189111494
commit
b75ff47098
@ -70,6 +70,10 @@ struct CompilerCache {
|
||||
return entries.back();
|
||||
};
|
||||
|
||||
void erase(size_t fun_id) {
|
||||
cache_.erase(fun_id);
|
||||
}
|
||||
|
||||
private:
|
||||
CompilerCache() {}
|
||||
friend CompilerCache& compiler_cache();
|
||||
@ -357,6 +361,11 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs);
|
||||
};
|
||||
}
|
||||
|
||||
void compile_erase(size_t fun_id) {
|
||||
detail::compiler_cache().erase(fun_id);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
|
@ -20,6 +20,9 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
size_t fun_id);
|
||||
|
||||
// Erase cached compile functions
|
||||
void compile_erase(size_t fun_id);
|
||||
|
||||
// Create an InTracing object during tracing operations to signify to the rest
|
||||
// of the codebase that we are during tracing so evals should not throw away
|
||||
// the graph.
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <pybind11/stl.h>
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iostream> // TODO
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
@ -455,14 +456,16 @@ std::unordered_map<size_t, py::object>& tree_cache() {
|
||||
return tree_cache_;
|
||||
}
|
||||
|
||||
auto py_compile(const py::function& fun) {
|
||||
return [fun](const py::args& args) {
|
||||
struct PyCompiledFun {
|
||||
py::function fun;
|
||||
|
||||
py::object operator()(const py::args& args) {
|
||||
// TODO, awni, I think this cast is ok??
|
||||
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
|
||||
|
||||
auto compile_fun = [fun_id, &fun, &args](const std::vector<array>& a) {
|
||||
auto compile_fun = [fun_id, this, &args](const std::vector<array>& a) {
|
||||
// Call the python function
|
||||
py::object py_outputs = fun(*tree_unflatten(args, a));
|
||||
py::object py_outputs = this->fun(*tree_unflatten(args, a));
|
||||
|
||||
// Flatten the outputs
|
||||
auto outputs = tree_flatten(py_outputs, true);
|
||||
@ -477,12 +480,20 @@ auto py_compile(const py::function& fun) {
|
||||
auto inputs = tree_flatten(args, true);
|
||||
|
||||
// Get globally enclosed arrays so we don't compile through them
|
||||
// c.f. https://github.com/python/cpython/blob/main/Lib/inspect.py#L1638
|
||||
if (py::hasattr(fun, "__globals__")) {
|
||||
auto global_inputs = tree_flatten(py::getattr(fun, "__globals__"), false);
|
||||
std::move(
|
||||
std::begin(global_inputs),
|
||||
std::end(global_inputs),
|
||||
std::back_inserter(inputs));
|
||||
py::dict globals = py::getattr(fun, "__globals__");
|
||||
auto co_names = py::getattr(py::getattr(fun, "__code__"), "co_names");
|
||||
for (auto& n : co_names) {
|
||||
if (py::cast<bool>(globals.attr("__contains__")(n))) {
|
||||
auto global_inputs =
|
||||
tree_flatten(globals.attr("__getitem__")(n), false);
|
||||
std::move(
|
||||
std::begin(global_inputs),
|
||||
std::end(global_inputs),
|
||||
std::back_inserter(inputs));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get locally enclosed arrays so we don't compile through them
|
||||
@ -507,7 +518,12 @@ auto py_compile(const py::function& fun) {
|
||||
py::object py_outputs = tree_cache().at(fun_id);
|
||||
return tree_unflatten_none(py_outputs, outputs);
|
||||
};
|
||||
}
|
||||
|
||||
~PyCompiledFun() {
|
||||
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
|
||||
detail::compile_erase(fun_id);
|
||||
}
|
||||
};
|
||||
|
||||
void init_transforms(py::module_& m) {
|
||||
py::options options;
|
||||
@ -810,7 +826,9 @@ void init_transforms(py::module_& m) {
|
||||
"file"_a);
|
||||
m.def(
|
||||
"compile",
|
||||
[](const py::function& fun) { return py::cpp_function(py_compile(fun)); },
|
||||
[](const py::function& fun) {
|
||||
return py::cpp_function(PyCompiledFun{fun});
|
||||
},
|
||||
"fun"_a,
|
||||
R"pbdoc(
|
||||
compile(fun: function) -> function
|
||||
|
Loading…
Reference in New Issue
Block a user