fix python globals bug, and erase

This commit is contained in:
Awni Hannun 2024-01-15 15:49:46 -08:00
parent 6189111494
commit b75ff47098
3 changed files with 41 additions and 11 deletions

View File

@ -70,6 +70,10 @@ struct CompilerCache {
return entries.back();
};
void erase(size_t fun_id) {
cache_.erase(fun_id);
}
private:
CompilerCache() {}
friend CompilerCache& compiler_cache();
@ -357,6 +361,11 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs);
};
}
void compile_erase(size_t fun_id) {
detail::compiler_cache().erase(fun_id);
}
} // namespace detail
std::function<std::vector<array>(const std::vector<array>&)> compile(

View File

@ -20,6 +20,9 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
size_t fun_id);
// Erase cached compile functions
void compile_erase(size_t fun_id);
// Create an InTracing object during tracing operations to signify to the rest
// of the codebase that we are during tracing so evals should not throw away
// the graph.

View File

@ -4,6 +4,7 @@
#include <pybind11/stl.h>
#include <algorithm>
#include <fstream>
#include <iostream> // TODO
#include <numeric>
#include <sstream>
@ -455,14 +456,16 @@ std::unordered_map<size_t, py::object>& tree_cache() {
return tree_cache_;
}
auto py_compile(const py::function& fun) {
return [fun](const py::args& args) {
struct PyCompiledFun {
py::function fun;
py::object operator()(const py::args& args) {
// TODO, awni, I think this cast is ok??
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
auto compile_fun = [fun_id, &fun, &args](const std::vector<array>& a) {
auto compile_fun = [fun_id, this, &args](const std::vector<array>& a) {
// Call the python function
py::object py_outputs = fun(*tree_unflatten(args, a));
py::object py_outputs = this->fun(*tree_unflatten(args, a));
// Flatten the outputs
auto outputs = tree_flatten(py_outputs, true);
@ -477,12 +480,20 @@ auto py_compile(const py::function& fun) {
auto inputs = tree_flatten(args, true);
// Get globally enclosed arrays so we don't compile through them
// c.f. https://github.com/python/cpython/blob/main/Lib/inspect.py#L1638
if (py::hasattr(fun, "__globals__")) {
auto global_inputs = tree_flatten(py::getattr(fun, "__globals__"), false);
std::move(
std::begin(global_inputs),
std::end(global_inputs),
std::back_inserter(inputs));
py::dict globals = py::getattr(fun, "__globals__");
auto co_names = py::getattr(py::getattr(fun, "__code__"), "co_names");
for (auto& n : co_names) {
if (py::cast<bool>(globals.attr("__contains__")(n))) {
auto global_inputs =
tree_flatten(globals.attr("__getitem__")(n), false);
std::move(
std::begin(global_inputs),
std::end(global_inputs),
std::back_inserter(inputs));
}
}
}
// Get locally enclosed arrays so we don't compile through them
@ -507,7 +518,12 @@ auto py_compile(const py::function& fun) {
py::object py_outputs = tree_cache().at(fun_id);
return tree_unflatten_none(py_outputs, outputs);
};
}
~PyCompiledFun() {
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
detail::compile_erase(fun_id);
}
};
void init_transforms(py::module_& m) {
py::options options;
@ -810,7 +826,9 @@ void init_transforms(py::module_& m) {
"file"_a);
m.def(
"compile",
[](const py::function& fun) { return py::cpp_function(py_compile(fun)); },
[](const py::function& fun) {
return py::cpp_function(PyCompiledFun{fun});
},
"fun"_a,
R"pbdoc(
compile(fun: function) -> function