mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 17:12:49 +08:00
Cycle leak break (#1856)
* detect and break leaks in custom function * detect and break leaks in custom function
This commit is contained in:
@@ -19,6 +19,7 @@
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/mlx_func.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
@@ -543,9 +544,9 @@ struct PyCompiledFun {
|
||||
|
||||
tree_cache().erase(fun_id);
|
||||
mx::detail::compile_erase(fun_id);
|
||||
fun.release().dec_ref();
|
||||
captured_inputs.release().dec_ref();
|
||||
captured_outputs.release().dec_ref();
|
||||
fun.reset();
|
||||
captured_inputs.reset();
|
||||
captured_outputs.reset();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -555,7 +556,7 @@ class PyCheckpointedFun {
|
||||
// Drops the reference to the wrapped Python callable.
// The GIL is acquired first because the callable's Python reference count is
// modified here.
~PyCheckpointedFun() {
  nb::gil_scoped_acquire gil;
  // A single reset() both decrements the reference and nulls the handle; the
  // former release().dec_ref() that preceded it made this call a no-op.
  fun_.reset();
}
|
||||
|
||||
struct InnerFunction {
|
||||
@@ -573,8 +574,8 @@ class PyCheckpointedFun {
|
||||
// Drops the Python references held by this inner function (the callable and
// the argument structure) while holding the GIL.
~InnerFunction() {
  nb::gil_scoped_acquire gil;
  // reset() decrements the reference and clears the handle in one step, so the
  // earlier release().dec_ref() pair on the same members was redundant.
  fun_.reset();
  args_structure_.reset();
}
|
||||
|
||||
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
|
||||
@@ -609,6 +610,10 @@ class PyCheckpointedFun {
|
||||
nb::callable fun_;
|
||||
};
|
||||
|
||||
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg);
|
||||
|
||||
int py_custom_function_tp_clear(PyObject* self);
|
||||
|
||||
/**
|
||||
* PyCustomFunction is the class that implements the python decorator
|
||||
* `mx.custom_function`.
|
||||
@@ -641,17 +646,7 @@ class PyCustomFunction {
|
||||
PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {}
|
||||
// Drops every Python callable held by this object (the primal function and any
// registered vjp / jvp / vmap overrides) while holding the GIL.
~PyCustomFunction() {
  nb::gil_scoped_acquire gil;
  // reset() clears fun_ and each engaged optional callable, so releasing the
  // members by hand before calling it only turned reset() into a no-op.
  reset();
}
|
||||
|
||||
struct InnerFunction {
|
||||
@@ -669,10 +664,10 @@ class PyCustomFunction {
|
||||
// Drops the Python references held by this inner function while holding the
// GIL.
~InnerFunction() {
  nb::gil_scoped_acquire gil;
  // One reset() per handle is sufficient; it decrements and nulls the handle.
  fun_.reset();
  input_structure_.reset();
  // The output structure is shared; only the last remaining holder clears the
  // Python object it points to.
  // NOTE(review): presumably this ensures the object is released under the GIL
  // rather than whenever the shared pointer is finally destroyed -- confirm.
  if (output_structure_.use_count() == 1) {
    output_structure_->reset();
  }
}
|
||||
|
||||
@@ -703,10 +698,10 @@ class PyCustomFunction {
|
||||
// Drops the Python references held by the VJP wrapper while holding the GIL.
~InnerVJPFunction() {
  nb::gil_scoped_acquire gil;
  // One reset() per handle is sufficient; it decrements and nulls the handle.
  vjp_fun_.reset();
  input_structure_.reset();
  // The output structure is shared; only the last remaining holder clears the
  // Python object it points to.
  // NOTE(review): presumably this ensures the object is released under the GIL
  // rather than whenever the shared pointer is finally destroyed -- confirm.
  if (output_structure_.use_count() == 1) {
    output_structure_->reset();
  }
}
|
||||
|
||||
@@ -746,8 +741,8 @@ class PyCustomFunction {
|
||||
// Drops the Python references held by the JVP wrapper (the callable and the
// input structure) while holding the GIL.
~InnerJVPFunction() {
  nb::gil_scoped_acquire gil;
  // reset() decrements the reference and clears the handle in one step, so the
  // earlier release().dec_ref() pair on the same members was redundant.
  jvp_fun_.reset();
  input_structure_.reset();
}
|
||||
|
||||
std::vector<mx::array> operator()(
|
||||
@@ -801,8 +796,8 @@ class PyCustomFunction {
|
||||
// Drops the Python references held by the vmap wrapper (the callable and the
// input structure) while holding the GIL.
~InnerVmapFunction() {
  nb::gil_scoped_acquire gil;
  // reset() decrements the reference and clears the handle in one step, so the
  // earlier release().dec_ref() pair on the same members was redundant.
  vmap_fun_.reset();
  input_structure_.reset();
}
|
||||
|
||||
std::pair<std::vector<mx::array>, std::vector<int>> operator()(
|
||||
@@ -904,6 +899,20 @@ class PyCustomFunction {
|
||||
vmap_fun_ = vmap_fun;
|
||||
return *this;
|
||||
}
|
||||
void reset() {
|
||||
fun_.reset();
|
||||
if (vjp_fun_.has_value()) {
|
||||
(*vjp_fun_).reset();
|
||||
}
|
||||
if (jvp_fun_.has_value()) {
|
||||
(*jvp_fun_).reset();
|
||||
}
|
||||
if (vmap_fun_.has_value()) {
|
||||
(*vmap_fun_).reset();
|
||||
}
|
||||
}
|
||||
|
||||
friend int py_custom_function_tp_traverse(PyObject*, visitproc, void*);
|
||||
|
||||
private:
|
||||
std::optional<InnerVJPFunction> make_vjp_function(
|
||||
@@ -940,10 +949,40 @@ class PyCustomFunction {
|
||||
std::optional<nb::callable> vmap_fun_;
|
||||
};
|
||||
|
||||
// tp_traverse slot for `custom_function`: reports to Python's cyclic garbage
// collector every Python object reachable from a PyCustomFunction instance
// (its type, the primal callable, and any vjp / jvp / vmap overrides).
//
// Returns 0 on success; Py_VISIT returns early with the visitor's non-zero
// result if a visit fails.
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) {
  // An instance implicitly depends on its (heap) type object.
  Py_VISIT(Py_TYPE(self));

  // The collector may traverse the object after __new__ but before the C++
  // constructor has completed; the members must not be inspected until then.
  if (!nb::inst_ready(self)) {
    return 0;
  }

  auto* p = nb::inst_ptr<PyCustomFunction>(self);

  // Py_VISIT tolerates a null pointer, so a cleared callable is harmless.
  nb::handle fun = nb::find(p->fun_);
  Py_VISIT(fun.ptr());

  if (p->vjp_fun_.has_value()) {
    nb::handle vjp = nb::find(*(p->vjp_fun_));
    Py_VISIT(vjp.ptr());
  }
  if (p->jvp_fun_.has_value()) {
    nb::handle jvp = nb::find(*(p->jvp_fun_));
    Py_VISIT(jvp.ptr());
  }
  if (p->vmap_fun_.has_value()) {
    nb::handle vmap = nb::find(*(p->vmap_fun_));
    Py_VISIT(vmap.ptr());
  }
  return 0;
}
|
||||
// tp_clear slot for `custom_function`: drops the Python callables held by the
// PyCustomFunction instance behind `self` by calling its reset(). Always
// returns 0.
int py_custom_function_tp_clear(PyObject* self) {
  nb::inst_ptr<PyCustomFunction>(self)->reset();
  return 0;
}
|
||||
// Extra type slots for `custom_function`: the traverse and clear callbacks
// above, terminated by a zeroed sentinel entry.
PyType_Slot py_custom_function_slots[] = {
    {Py_tp_traverse, reinterpret_cast<void*>(py_custom_function_tp_traverse)},
    {Py_tp_clear, reinterpret_cast<void*>(py_custom_function_tp_clear)},
    {0, nullptr}};
|
||||
|
||||
void init_transforms(nb::module_& m) {
|
||||
nb::class_<PyCustomFunction>(
|
||||
m,
|
||||
"custom_function",
|
||||
nb::type_slots(py_custom_function_slots),
|
||||
R"pbdoc(
|
||||
Set up a function for custom gradient and vmap definitions.
|
||||
|
||||
@@ -1224,8 +1263,10 @@ void init_transforms(nb::module_& m) {
|
||||
const StrOrSet& argnames) {
|
||||
auto [argnums_vec, argnames_set] =
|
||||
validate_argnums_argnames(argnums, argnames);
|
||||
return nb::cpp_function(py_value_and_grad(
|
||||
fun, argnums_vec, argnames_set, "[value_and_grad]", false));
|
||||
return mlx_func(
|
||||
py_value_and_grad(
|
||||
fun, argnums_vec, argnames_set, "[value_and_grad]", false),
|
||||
fun);
|
||||
},
|
||||
"fun"_a,
|
||||
"argnums"_a = nb::none(),
|
||||
@@ -1290,9 +1331,11 @@ void init_transforms(nb::module_& m) {
|
||||
validate_argnums_argnames(argnums, argnames);
|
||||
auto fn =
|
||||
py_value_and_grad(fun, argnums_vec, argnames_set, "[grad]", true);
|
||||
return nb::cpp_function([fn](nb::args& args, nb::kwargs& kwargs) {
|
||||
return fn(args, kwargs).second;
|
||||
});
|
||||
return mlx_func(
|
||||
[fn = std::move(fn)](nb::args& args, nb::kwargs& kwargs) {
|
||||
return fn(args, kwargs).second;
|
||||
},
|
||||
fun);
|
||||
},
|
||||
"fun"_a,
|
||||
"argnums"_a = nb::none(),
|
||||
@@ -1324,7 +1367,8 @@ void init_transforms(nb::module_& m) {
|
||||
[](const nb::callable& fun,
|
||||
const nb::object& in_axes,
|
||||
const nb::object& out_axes) {
|
||||
return nb::cpp_function(py_vmap(fun, in_axes, out_axes));
|
||||
return mlx_func(
|
||||
py_vmap(fun, in_axes, out_axes), fun, in_axes, out_axes);
|
||||
},
|
||||
"fun"_a,
|
||||
"in_axes"_a = 0,
|
||||
@@ -1379,11 +1423,15 @@ void init_transforms(nb::module_& m) {
|
||||
d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);
|
||||
|
||||
auto sig_str = sig.str();
|
||||
return nb::cpp_function(
|
||||
PyCompiledFun{fun, inputs, outputs, shapeless},
|
||||
nb::name(name.c_str()),
|
||||
nb::sig(sig_str.c_str()),
|
||||
doc.c_str());
|
||||
return mlx_func(
|
||||
nb::cpp_function(
|
||||
PyCompiledFun{fun, inputs, outputs, shapeless},
|
||||
nb::name(name.c_str()),
|
||||
nb::sig(sig_str.c_str()),
|
||||
doc.c_str()),
|
||||
fun,
|
||||
inputs,
|
||||
outputs);
|
||||
},
|
||||
"fun"_a,
|
||||
"inputs"_a = nb::none(),
|
||||
@@ -1435,7 +1483,7 @@ void init_transforms(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"checkpoint",
|
||||
[](nb::callable fun) { return nb::cpp_function(PyCheckpointedFun{fun}); },
|
||||
[](nb::callable fun) { return mlx_func(PyCheckpointedFun{fun}, fun); },
|
||||
"fun"_a);
|
||||
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
|
||||
Reference in New Issue
Block a user