diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index ca1931ec3..a7233a468 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -465,15 +465,12 @@ struct PyCompiledFun { PyCompiledFun(const PyCompiledFun&) = delete; PyCompiledFun& operator=(const PyCompiledFun&) = delete; PyCompiledFun& operator=(PyCompiledFun&& other) = delete; - PyCompiledFun(PyCompiledFun&& other) { - fun = other.fun; + PyCompiledFun(PyCompiledFun&& other) + : fun(std::move(other.fun)), fun_id(reinterpret_cast(fun.ptr())) { other.fun_id = 0; - fun_id = reinterpret_cast(fun.ptr()); }; py::object operator()(const py::args& args) { - // TODO, awni, I think this cast is ok?? - auto compile_fun = [this, &args](const std::vector& a) { // Call the python function py::object py_outputs = this->fun(*tree_unflatten(args, a)); @@ -531,6 +528,7 @@ struct PyCompiledFun { }; ~PyCompiledFun() { + tree_cache().erase(fun_id); detail::compile_erase(fun_id); } };