clear output tree from cache when function goes out of scope

This commit is contained in:
Awni Hannun 2024-01-17 07:01:00 -08:00
parent 390e50b163
commit db40990d33

View File

@ -465,15 +465,12 @@ struct PyCompiledFun {
PyCompiledFun(const PyCompiledFun&) = delete; PyCompiledFun(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(const PyCompiledFun&) = delete; PyCompiledFun& operator=(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(PyCompiledFun&& other) = delete; PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
PyCompiledFun(PyCompiledFun&& other) { PyCompiledFun(PyCompiledFun&& other)
fun = other.fun; : fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
other.fun_id = 0; other.fun_id = 0;
fun_id = reinterpret_cast<size_t>(fun.ptr());
}; };
py::object operator()(const py::args& args) { py::object operator()(const py::args& args) {
// TODO, awni, I think this cast is ok??
auto compile_fun = [this, &args](const std::vector<array>& a) { auto compile_fun = [this, &args](const std::vector<array>& a) {
// Call the python function // Call the python function
py::object py_outputs = this->fun(*tree_unflatten(args, a)); py::object py_outputs = this->fun(*tree_unflatten(args, a));
@ -531,6 +528,7 @@ struct PyCompiledFun {
}; };
~PyCompiledFun() { ~PyCompiledFun() {
tree_cache().erase(fun_id);
detail::compile_erase(fun_id); detail::compile_erase(fun_id);
} }
}; };