mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
clear output tree from cache when function goes out of scope
This commit is contained in:
parent
390e50b163
commit
db40990d33
@ -465,15 +465,12 @@ struct PyCompiledFun {
|
|||||||
PyCompiledFun(const PyCompiledFun&) = delete;
|
PyCompiledFun(const PyCompiledFun&) = delete;
|
||||||
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
||||||
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
|
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
|
||||||
PyCompiledFun(PyCompiledFun&& other) {
|
PyCompiledFun(PyCompiledFun&& other)
|
||||||
fun = other.fun;
|
: fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
|
||||||
other.fun_id = 0;
|
other.fun_id = 0;
|
||||||
fun_id = reinterpret_cast<size_t>(fun.ptr());
|
|
||||||
};
|
};
|
||||||
|
|
||||||
py::object operator()(const py::args& args) {
|
py::object operator()(const py::args& args) {
|
||||||
// TODO, awni, I think this cast is ok??
|
|
||||||
|
|
||||||
auto compile_fun = [this, &args](const std::vector<array>& a) {
|
auto compile_fun = [this, &args](const std::vector<array>& a) {
|
||||||
// Call the python function
|
// Call the python function
|
||||||
py::object py_outputs = this->fun(*tree_unflatten(args, a));
|
py::object py_outputs = this->fun(*tree_unflatten(args, a));
|
||||||
@ -531,6 +528,7 @@ struct PyCompiledFun {
|
|||||||
};
|
};
|
||||||
|
|
||||||
~PyCompiledFun() {
|
~PyCompiledFun() {
|
||||||
|
tree_cache().erase(fun_id);
|
||||||
detail::compile_erase(fun_id);
|
detail::compile_erase(fun_id);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user