Cycle leak break (#1856)

* detect and break leaks in custom function

* detect and break leaks in custom function
This commit is contained in:
Awni Hannun
2025-02-11 14:45:02 -08:00
committed by GitHub
parent 142b77751d
commit 2a45056ba8
10 changed files with 396 additions and 49 deletions

View File

@@ -19,6 +19,7 @@
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
#include "python/src/mlx_func.h"
#include "python/src/trees.h"
namespace mx = mlx::core;
@@ -543,9 +544,9 @@ struct PyCompiledFun {
tree_cache().erase(fun_id);
mx::detail::compile_erase(fun_id);
fun.release().dec_ref();
captured_inputs.release().dec_ref();
captured_outputs.release().dec_ref();
fun.reset();
captured_inputs.reset();
captured_outputs.reset();
}
};
@@ -555,7 +556,7 @@ class PyCheckpointedFun {
// Drops the reference to the wrapped Python callable.
~PyCheckpointedFun() {
  // A Python reference count is modified below, so hold the GIL.
  nb::gil_scoped_acquire gil;
  // reset() decrements the reference and nulls the handle in one step; a
  // separate release().dec_ref() beforehand would only make this a no-op.
  fun_.reset();
}
struct InnerFunction {
@@ -573,8 +574,8 @@ class PyCheckpointedFun {
// Drops the references held by this inner function object.
~InnerFunction() {
  // Python reference counts are modified below, so hold the GIL.
  nb::gil_scoped_acquire gil;
  // One reset() per handle is sufficient: it decrements the reference and
  // leaves the handle empty.
  fun_.reset();
  args_structure_.reset();
}
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
@@ -609,6 +610,10 @@ class PyCheckpointedFun {
nb::callable fun_;
};
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg);
int py_custom_function_tp_clear(PyObject* self);
/**
* PyCustomFunction is the class that implements the python decorator
* `mx.custom_function`.
@@ -641,17 +646,7 @@ class PyCustomFunction {
PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {}
// Drops the references to the wrapped callable and to any registered
// vjp/jvp/vmap callables.
~PyCustomFunction() {
  // Python reference counts are modified by reset(), so hold the GIL.
  nb::gil_scoped_acquire gil;
  // reset() already releases fun_ and every optional callable that is set,
  // so no per-member release is needed here.
  reset();
}
struct InnerFunction {
@@ -669,10 +664,10 @@ class PyCustomFunction {
// Drops the references held by this inner function object.
~InnerFunction() {
  // Python reference counts are modified below, so hold the GIL.
  nb::gil_scoped_acquire gil;
  fun_.reset();
  input_structure_.reset();
  // The output structure is shared; only the last remaining holder drops
  // the Python reference it points to.
  if (output_structure_.use_count() == 1) {
    output_structure_->reset();
  }
}
@@ -703,10 +698,10 @@ class PyCustomFunction {
// Drops the references held by this inner VJP function object.
~InnerVJPFunction() {
  // Python reference counts are modified below, so hold the GIL.
  nb::gil_scoped_acquire gil;
  vjp_fun_.reset();
  input_structure_.reset();
  // The output structure is shared; only the last remaining holder drops
  // the Python reference it points to.
  if (output_structure_.use_count() == 1) {
    output_structure_->reset();
  }
}
@@ -746,8 +741,8 @@ class PyCustomFunction {
// Drops the references held by this inner JVP function object.
~InnerJVPFunction() {
  // Python reference counts are modified below, so hold the GIL.
  nb::gil_scoped_acquire gil;
  // One reset() per handle is sufficient: it decrements the reference and
  // leaves the handle empty.
  jvp_fun_.reset();
  input_structure_.reset();
}
std::vector<mx::array> operator()(
@@ -801,8 +796,8 @@ class PyCustomFunction {
// Drops the references held by this inner vmap function object.
~InnerVmapFunction() {
  // Python reference counts are modified below, so hold the GIL.
  nb::gil_scoped_acquire gil;
  // One reset() per handle is sufficient: it decrements the reference and
  // leaves the handle empty.
  vmap_fun_.reset();
  input_structure_.reset();
}
std::pair<std::vector<mx::array>, std::vector<int>> operator()(
@@ -904,6 +899,20 @@ class PyCustomFunction {
vmap_fun_ = vmap_fun;
return *this;
}
// Releases the wrapped callable and each optional vjp/jvp/vmap callable
// that has been set. The optionals themselves stay engaged; only the
// handles they contain are reset.
void reset() {
  fun_.reset();
  if (vjp_fun_) {
    vjp_fun_->reset();
  }
  if (jvp_fun_) {
    jvp_fun_->reset();
  }
  if (vmap_fun_) {
    vmap_fun_->reset();
  }
}
friend int py_custom_function_tp_traverse(PyObject*, visitproc, void*);
private:
std::optional<InnerVJPFunction> make_vjp_function(
@@ -940,10 +949,40 @@ class PyCustomFunction {
std::optional<nb::callable> vmap_fun_;
};
/**
 * tp_traverse slot for PyCustomFunction: reports the Python objects held by
 * the instance (the wrapped callable and any vjp/jvp/vmap callables) to the
 * garbage collector via `visit`.
 *
 * Returns 0 on success, or the first non-zero value returned by `visit`
 * (Py_VISIT returns early in that case).
 */
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) {
  // The instance's type is visited as well; doing it first means it is
  // reported even when the early return below is taken.
  Py_VISIT(Py_TYPE(self));

  // The collector may traverse an instance whose C++ constructor has not
  // finished yet; its members must not be inspected in that state.
  if (!nb::inst_ready(self)) {
    return 0;
  }

  auto* p = nb::inst_ptr<PyCustomFunction>(self);

  // Py_VISIT skips null pointers, so a callable without an associated
  // Python object is simply ignored.
  nb::handle fun = nb::find(p->fun_);
  Py_VISIT(fun.ptr());
  if (p->vjp_fun_.has_value()) {
    nb::handle vjp = nb::find(*(p->vjp_fun_));
    Py_VISIT(vjp.ptr());
  }
  if (p->jvp_fun_.has_value()) {
    nb::handle jvp = nb::find(*(p->jvp_fun_));
    Py_VISIT(jvp.ptr());
  }
  if (p->vmap_fun_.has_value()) {
    nb::handle vmap = nb::find(*(p->vmap_fun_));
    Py_VISIT(vmap.ptr());
  }
  return 0;
}
/**
 * tp_clear slot for PyCustomFunction: drops the callables held by the
 * instance by calling its reset() method. Always returns 0.
 */
int py_custom_function_tp_clear(PyObject* self) {
  nb::inst_ptr<PyCustomFunction>(self)->reset();
  return 0;
}
// Extra type slots for the `custom_function` binding: the traverse and
// clear callbacks defined above, terminated by a zero entry.
PyType_Slot py_custom_function_slots[] = {
    {Py_tp_traverse,
     reinterpret_cast<void*>(py_custom_function_tp_traverse)},
    {Py_tp_clear, reinterpret_cast<void*>(py_custom_function_tp_clear)},
    {0, nullptr}};
void init_transforms(nb::module_& m) {
nb::class_<PyCustomFunction>(
m,
"custom_function",
nb::type_slots(py_custom_function_slots),
R"pbdoc(
Set up a function for custom gradient and vmap definitions.
@@ -1224,8 +1263,10 @@ void init_transforms(nb::module_& m) {
const StrOrSet& argnames) {
auto [argnums_vec, argnames_set] =
validate_argnums_argnames(argnums, argnames);
return nb::cpp_function(py_value_and_grad(
fun, argnums_vec, argnames_set, "[value_and_grad]", false));
return mlx_func(
py_value_and_grad(
fun, argnums_vec, argnames_set, "[value_and_grad]", false),
fun);
},
"fun"_a,
"argnums"_a = nb::none(),
@@ -1290,9 +1331,11 @@ void init_transforms(nb::module_& m) {
validate_argnums_argnames(argnums, argnames);
auto fn =
py_value_and_grad(fun, argnums_vec, argnames_set, "[grad]", true);
return nb::cpp_function([fn](nb::args& args, nb::kwargs& kwargs) {
return fn(args, kwargs).second;
});
return mlx_func(
[fn = std::move(fn)](nb::args& args, nb::kwargs& kwargs) {
return fn(args, kwargs).second;
},
fun);
},
"fun"_a,
"argnums"_a = nb::none(),
@@ -1324,7 +1367,8 @@ void init_transforms(nb::module_& m) {
[](const nb::callable& fun,
const nb::object& in_axes,
const nb::object& out_axes) {
return nb::cpp_function(py_vmap(fun, in_axes, out_axes));
return mlx_func(
py_vmap(fun, in_axes, out_axes), fun, in_axes, out_axes);
},
"fun"_a,
"in_axes"_a = 0,
@@ -1379,11 +1423,15 @@ void init_transforms(nb::module_& m) {
d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);
auto sig_str = sig.str();
return nb::cpp_function(
PyCompiledFun{fun, inputs, outputs, shapeless},
nb::name(name.c_str()),
nb::sig(sig_str.c_str()),
doc.c_str());
return mlx_func(
nb::cpp_function(
PyCompiledFun{fun, inputs, outputs, shapeless},
nb::name(name.c_str()),
nb::sig(sig_str.c_str()),
doc.c_str()),
fun,
inputs,
outputs);
},
"fun"_a,
"inputs"_a = nb::none(),
@@ -1435,7 +1483,7 @@ void init_transforms(nb::module_& m) {
)pbdoc");
m.def(
"checkpoint",
[](nb::callable fun) { return nb::cpp_function(PyCheckpointedFun{fun}); },
[](nb::callable fun) { return mlx_func(PyCheckpointedFun{fun}, fun); },
"fun"_a);
// Register static Python object cleanup before the interpreter exits