clear output tree from cache when function goes out of scope

This commit is contained in:
Awni Hannun 2024-01-17 07:01:00 -08:00
parent 390e50b163
commit db40990d33

View File

@ -465,15 +465,12 @@ struct PyCompiledFun {
PyCompiledFun(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
PyCompiledFun(PyCompiledFun&& other) {
fun = other.fun;
PyCompiledFun(PyCompiledFun&& other)
: fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
other.fun_id = 0;
fun_id = reinterpret_cast<size_t>(fun.ptr());
};
py::object operator()(const py::args& args) {
// TODO, awni, I think this cast is ok??
auto compile_fun = [this, &args](const std::vector<array>& a) {
// Call the python function
py::object py_outputs = this->fun(*tree_unflatten(args, a));
@ -531,6 +528,7 @@ struct PyCompiledFun {
};
~PyCompiledFun() {
tree_cache().erase(fun_id);
detail::compile_erase(fun_id);
}
};