Cycle leak break (#1856)

* detect and break leaks in custom function

* detect and break leaks in custom function
This commit is contained in:
Awni Hannun 2025-02-11 14:45:02 -08:00 committed by GitHub
parent 142b77751d
commit 2a45056ba8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 396 additions and 49 deletions

View File

@ -17,6 +17,7 @@ nanobind_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp

View File

@ -60,8 +60,56 @@ validate_and_extract_inputs(
return {args_, kwargs_};
}
auto wrap_export_function(const nb::callable& fun) {
return [fun](
int py_function_exporter_tp_traverse(
PyObject* self,
visitproc visit,
void* arg);
// Python-facing wrapper around mx::FunctionExporter. Next to the exporter it
// stores an nb::handle to a Python object so that
// py_function_exporter_tp_traverse can report that object to the garbage
// collector.
class PyFunctionExporter {
 public:
  PyFunctionExporter(mx::FunctionExporter exporter, nb::handle dep)
      : exporter_(std::move(exporter)), dep_(dep) {}

  // Move construction is the only supported way to transfer an instance;
  // copying and both assignment forms are disabled.
  PyFunctionExporter(PyFunctionExporter&& other)
      : exporter_(std::move(other.exporter_)), dep_(std::move(other.dep_)) {}
  PyFunctionExporter(const PyFunctionExporter&) = delete;
  PyFunctionExporter& operator=(const PyFunctionExporter&) = delete;
  PyFunctionExporter& operator=(const PyFunctionExporter&&) = delete;

  ~PyFunctionExporter() {
    // Take the GIL for the duration of the destructor body.
    nb::gil_scoped_acquire gil;
  }

  // Forwards to the wrapped exporter's close().
  void close() {
    exporter_.close();
  }

  // Forwards positional and keyword arrays to the wrapped exporter.
  void operator()(
      const std::vector<mx::array>& args,
      const std::map<std::string, mx::array>& kwargs) {
    exporter_(args, kwargs);
  }

  // The traverse hook needs to read dep_.
  friend int py_function_exporter_tp_traverse(PyObject*, visitproc, void*);

 private:
  mx::FunctionExporter exporter_;
  nb::handle dep_;
};
// tp_traverse hook for PyFunctionExporter: reports the stored dependency
// handle and the instance's type object to the garbage collector.
int py_function_exporter_tp_traverse(
    PyObject* self,
    visitproc visit,
    void* arg) {
  PyFunctionExporter* exporter = nb::inst_ptr<PyFunctionExporter>(self);
  PyObject* dependency = exporter->dep_.ptr();
  Py_VISIT(dependency);
  Py_VISIT(Py_TYPE(self));
  return 0;
}
// Extra type slots for the FunctionExporter binding: only the GC traverse
// hook, followed by the zero sentinel entry.
PyType_Slot py_function_exporter_slots[] = {
    {Py_tp_traverse,
     reinterpret_cast<void*>(py_function_exporter_tp_traverse)},
    {0, nullptr}};
auto wrap_export_function(nb::callable fun) {
return [fun = std::move(fun)](
const std::vector<mx::array>& args_,
const std::map<std::string, mx::array>& kwargs_) {
auto kwargs = nb::dict();
@ -173,21 +221,21 @@ void init_export(nb::module_& m) {
>>> out = fn((a, b), {"x": x, "y": y})[0]
)pbdoc");
nb::class_<mx::FunctionExporter>(
nb::class_<PyFunctionExporter>(
m,
"FunctionExporter",
nb::type_slots(py_function_exporter_slots),
R"pbdoc(
A context managing class for exporting multiple traces of the same
function to a file.
Make an instance of this class by calling :func:`mx.exporter`.
)pbdoc")
.def("close", &mx::FunctionExporter::close)
.def(
"__enter__", [](mx::FunctionExporter& exporter) { return &exporter; })
.def("close", &PyFunctionExporter::close)
.def("__enter__", [](PyFunctionExporter& exporter) { return &exporter; })
.def(
"__exit__",
[](mx::FunctionExporter& exporter,
[](PyFunctionExporter& exporter,
const std::optional<nb::object>&,
const std::optional<nb::object>&,
const std::optional<nb::object>&) { exporter.close(); },
@ -196,7 +244,7 @@ void init_export(nb::module_& m) {
"traceback"_a = nb::none())
.def(
"__call__",
[](mx::FunctionExporter& exporter,
[](PyFunctionExporter& exporter,
const nb::args& args,
const nb::kwargs& kwargs) {
auto [args_, kwargs_] =
@ -206,8 +254,9 @@ void init_export(nb::module_& m) {
m.def(
"exporter",
[](const std::string& file, const nb::callable& fun, bool shapeless) {
return mx::exporter(file, wrap_export_function(fun), shapeless);
[](const std::string& file, nb::callable fun, bool shapeless) {
return PyFunctionExporter{
mx::exporter(file, wrap_export_function(fun), shapeless), fun};
},
"file"_a,
"fun"_a,

View File

@ -7,6 +7,7 @@
namespace nb = nanobind;
void init_mlx_func(nb::module_&);
void init_array(nb::module_&);
void init_device(nb::module_&);
void init_stream(nb::module_&);
@ -28,6 +29,7 @@ NB_MODULE(core, m) {
nb::module_::import_("mlx._os_warning");
nb::set_leak_warnings(false);
init_mlx_func(m);
init_device(m);
init_stream(m);
init_array(m);

108
python/src/mlx_func.cpp Normal file
View File

@ -0,0 +1,108 @@
// Copyright © 2025 Apple Inc.
#include "python/src/mlx_func.h"

#include <new>
// A garbage collected function which wraps nb::cpp_function
// See https://github.com/wjakob/nanobind/discussions/919
// Instance layout of the garbage-collected function wrapper type.
struct gc_func {
PyObject_HEAD
// Vector call implementation that forwards calls to nanobind. Its offset is
// published through the "__vectorcalloffset__" member of the type.
PyObject* (*vectorcall)(PyObject*, PyObject* const*, size_t, PyObject*);
// The function itself. A reference is taken when the wrapper is created and
// released again in tp_clear / dealloc.
PyObject* func;
// A non-owning reference to dependencies owned by 'func'. These pointers are
// only reported to the collector in tp_traverse; their reference counts are
// never touched by this type.
std::vector<PyObject*> deps;
};
// tp_traverse hook: reports the wrapped function, every registered dependency
// and the instance's type object to the garbage collector.
int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) {
  auto* wrapper = reinterpret_cast<gc_func*>(self);
  Py_VISIT(wrapper->func);
  for (PyObject* dependency : wrapper->deps) {
    Py_VISIT(dependency);
  }
  Py_VISIT(Py_TYPE(self));
  return 0;
}
// tp_clear hook: breaks reference cycles that run through this wrapper.
//
// The entries of 'deps' are non-owning and are kept alive by 'func'. They are
// forgotten before 'func' is released so that a later traversal of a
// still-alive wrapper can never visit pointers to objects that have already
// been freed.
int gc_func_tp_clear(PyObject* self) {
  gc_func* w = (gc_func*)self;
  w->deps.clear();
  Py_CLEAR(w->func);
  return 0;
}
// Getter for "__doc__": returns the attribute of the same name looked up on
// the wrapped function.
PyObject* gc_func_get_doc(PyObject* self, void*) {
  PyObject* wrapped = reinterpret_cast<gc_func*>(self)->func;
  return PyObject_GetAttrString(wrapped, "__doc__");
}
// Getter for "__nb_signature__": returns the attribute of the same name
// looked up on the wrapped function.
PyObject* gc_func_get_sig(PyObject* self, void*) {
  PyObject* wrapped = reinterpret_cast<gc_func*>(self)->func;
  return PyObject_GetAttrString(wrapped, "__nb_signature__");
}
// Vectorcall entry point of the wrapper: passes the call through unchanged to
// the wrapped function.
PyObject* gc_func_vectorcall(
    PyObject* self,
    PyObject* const* args,
    size_t nargs,
    PyObject* kwnames) {
  PyObject* callee = reinterpret_cast<gc_func*>(self)->func;
  return PyObject_Vectorcall(callee, args, nargs, kwnames);
}
// tp_dealloc hook for gc_func instances.
//
// Besides releasing the wrapped function this also
//  - runs the destructor of the 'deps' vector, whose heap buffer would
//    otherwise leak because the object memory is freed as raw storage, and
//  - drops the reference to the type object: the type is created with
//    PyType_FromSpec (a heap type), so each instance holds a reference to it.
void gc_func_dealloc(PyObject* self) {
  gc_func* w = reinterpret_cast<gc_func*>(self);
  // Read the type before the instance memory is released.
  PyTypeObject* tp = Py_TYPE(self);
  PyObject_GC_UnTrack(self);
  Py_XDECREF(w->func);
  using DepVector = std::vector<PyObject*>;
  w->deps.~DepVector();
  PyObject_GC_Del(self);
  Py_DECREF(tp);
}
// Member table for the wrapper type. Its single entry publishes the byte
// offset of gc_func::vectorcall under the name "__vectorcalloffset__"
// (read-only); the all-null entry terminates the table.
static PyMemberDef gc_func_members[] = {
{"__vectorcalloffset__",
T_PYSSIZET,
(Py_ssize_t)offsetof(gc_func, vectorcall),
READONLY,
nullptr},
{nullptr, 0, 0, 0, nullptr}};
// Computed attributes of the wrapper type. Both entries have a getter only
// (no setter) and read the attribute of the same name from the wrapped
// function; the all-null entry terminates the table.
static PyGetSetDef gc_func_getset[] = {
{"__doc__", gc_func_get_doc, nullptr, nullptr, nullptr},
{"__nb_signature__", gc_func_get_sig, nullptr, nullptr, nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr}};
// tp_getattro hook: resolves attribute lookups on the wrapper by invoking the
// tp_getattro slot of the wrapped function's type on the wrapped function.
static PyObject* gc_func_getattro(PyObject* self, PyObject* name_) {
  PyObject* target = reinterpret_cast<gc_func*>(self)->func;
  void* slot = PyType_GetSlot(Py_TYPE(target), Py_tp_getattro);
  auto lookup = reinterpret_cast<PyCFunction>(slot);
  return lookup(target, name_);
}
// Table of custom type slots we want to install: GC support
// (traverse / clear), attribute access (getset / getattro / members), call
// support through PyVectorcall_Call, and the deallocator. The {0, 0} entry
// terminates the table.
PyType_Slot gc_func_slots[] = {
{Py_tp_traverse, (void*)gc_func_tp_traverse},
{Py_tp_clear, (void*)gc_func_tp_clear},
{Py_tp_getset, (void*)gc_func_getset},
{Py_tp_getattro, (void*)gc_func_getattro},
{Py_tp_members, (void*)gc_func_members},
{Py_tp_call, (void*)PyVectorcall_Call},
{Py_tp_dealloc, (void*)gc_func_dealloc},
{0, 0}};
// Specification handed to PyType_FromSpec in init_mlx_func. The flags enable
// garbage-collector support and vectorcall for instances of the type.
static PyType_Spec gc_func_spec = {
/* .name = */ "mlx.gc_func",
/* .basicsize = */ (int)sizeof(gc_func),
/* .itemsize = */ 0,
/* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | NB_HAVE_VECTORCALL,
/* .slots = */ gc_func_slots};
// The type object created from gc_func_spec; null until init_mlx_func runs.
static PyTypeObject* gc_func_tp = nullptr;
// Wraps 'func' in a garbage-collected callable.
//
// func: the callable to forward to; the wrapper takes its own reference.
// deps: Python objects kept alive by 'func'. They are stored without taking
//       references and are only reported to the garbage collector.
//
// Returns a new reference to the wrapper. If the allocation fails, the pending
// Python error is propagated as a C++ exception.
nb::callable mlx_func(nb::object func, std::vector<PyObject*> deps) {
  auto* r = reinterpret_cast<gc_func*>(PyType_GenericAlloc(gc_func_tp, 0));
  if (r == nullptr) {
    nb::raise_python_error();
  }
  // The allocator only hands back zero-filled storage, so the vector member
  // has to be constructed explicitly before it is used.
  ::new (static_cast<void*>(&r->deps)) std::vector<PyObject*>(std::move(deps));
  r->func = func.inc_ref().ptr();
  r->vectorcall = gc_func_vectorcall;
  return nb::steal<nb::callable>(reinterpret_cast<PyObject*>(r));
}
// Creates the wrapper type from gc_func_spec and stores it in gc_func_tp.
// Raises if the type could not be created.
void init_mlx_func(nb::module_& m) {
  PyObject* type_object = PyType_FromSpec(&gc_func_spec);
  gc_func_tp = reinterpret_cast<PyTypeObject*>(type_object);
  if (gc_func_tp == nullptr) {
    nb::raise("Could not register MLX function type.");
  }
}

24
python/src/mlx_func.h Normal file
View File

@ -0,0 +1,24 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <vector>
#include <nanobind/nanobind.h>
#include <nanobind/stl/function.h>
namespace nb = nanobind;
using namespace nb::literals;
nb::callable mlx_func(nb::object func, std::vector<PyObject*> deps);
// Convenience overload: turns a C++ callable into a nanobind function and
// wraps it, collecting the raw pointers of 'deps' (each must provide .ptr()).
template <typename F, typename... Deps>
nb::callable mlx_func(F func, Deps&&... deps) {
  std::vector<PyObject*> raw_deps{deps.ptr()...};
  return mlx_func(nb::cpp_function(std::move(func)), std::move(raw_deps));
}
// Convenience overload for an existing Python object: wraps it, collecting
// the raw pointers of 'deps' (each must provide .ptr()).
template <typename... Deps>
nb::callable mlx_func(nb::object func, Deps&&... deps) {
  std::vector<PyObject*> raw_deps{deps.ptr()...};
  return mlx_func(std::move(func), std::move(raw_deps));
}

View File

@ -19,6 +19,7 @@
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
#include "python/src/mlx_func.h"
#include "python/src/trees.h"
namespace mx = mlx::core;
@ -543,9 +544,9 @@ struct PyCompiledFun {
tree_cache().erase(fun_id);
mx::detail::compile_erase(fun_id);
fun.release().dec_ref();
captured_inputs.release().dec_ref();
captured_outputs.release().dec_ref();
fun.reset();
captured_inputs.reset();
captured_outputs.reset();
}
};
@ -555,7 +556,7 @@ class PyCheckpointedFun {
~PyCheckpointedFun() {
nb::gil_scoped_acquire gil;
fun_.release().dec_ref();
fun_.reset();
}
struct InnerFunction {
@ -573,8 +574,8 @@ class PyCheckpointedFun {
~InnerFunction() {
nb::gil_scoped_acquire gil;
fun_.release().dec_ref();
args_structure_.release().dec_ref();
fun_.reset();
args_structure_.reset();
}
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
@ -609,6 +610,10 @@ class PyCheckpointedFun {
nb::callable fun_;
};
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg);
int py_custom_function_tp_clear(PyObject* self);
/**
* PyCustomFunction is the class that implements the python decorator
* `mx.custom_function`.
@ -641,17 +646,7 @@ class PyCustomFunction {
PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {}
~PyCustomFunction() {
nb::gil_scoped_acquire gil;
fun_.release().dec_ref();
if (vjp_fun_.has_value()) {
(*vjp_fun_).release().dec_ref();
}
if (jvp_fun_.has_value()) {
(*jvp_fun_).release().dec_ref();
}
if (vmap_fun_.has_value()) {
(*vmap_fun_).release().dec_ref();
}
reset();
}
struct InnerFunction {
@ -669,10 +664,10 @@ class PyCustomFunction {
~InnerFunction() {
nb::gil_scoped_acquire gil;
fun_.release().dec_ref();
input_structure_.release().dec_ref();
fun_.reset();
input_structure_.reset();
if (output_structure_.use_count() == 1) {
output_structure_->release().dec_ref();
output_structure_->reset();
}
}
@ -703,10 +698,10 @@ class PyCustomFunction {
~InnerVJPFunction() {
nb::gil_scoped_acquire gil;
vjp_fun_.release().dec_ref();
input_structure_.release().dec_ref();
vjp_fun_.reset();
input_structure_.reset();
if (output_structure_.use_count() == 1) {
output_structure_->release().dec_ref();
output_structure_->reset();
}
}
@ -746,8 +741,8 @@ class PyCustomFunction {
~InnerJVPFunction() {
nb::gil_scoped_acquire gil;
jvp_fun_.release().dec_ref();
input_structure_.release().dec_ref();
jvp_fun_.reset();
input_structure_.reset();
}
std::vector<mx::array> operator()(
@ -801,8 +796,8 @@ class PyCustomFunction {
~InnerVmapFunction() {
nb::gil_scoped_acquire gil;
vmap_fun_.release().dec_ref();
input_structure_.release().dec_ref();
vmap_fun_.reset();
input_structure_.reset();
}
std::pair<std::vector<mx::array>, std::vector<int>> operator()(
@ -904,6 +899,20 @@ class PyCustomFunction {
vmap_fun_ = vmap_fun;
return *this;
}
// Resets the wrapped function and whichever of the optional vjp / jvp / vmap
// callables are currently set.
void reset() {
  fun_.reset();
  if (vjp_fun_) {
    vjp_fun_->reset();
  }
  if (jvp_fun_) {
    jvp_fun_->reset();
  }
  if (vmap_fun_) {
    vmap_fun_->reset();
  }
}
friend int py_custom_function_tp_traverse(PyObject*, visitproc, void*);
private:
std::optional<InnerVJPFunction> make_vjp_function(
@ -940,10 +949,40 @@ class PyCustomFunction {
std::optional<nb::callable> vmap_fun_;
};
// tp_traverse hook for PyCustomFunction: reports the wrapped function, each
// optional transformation callable that is set, and the instance's type
// object to the garbage collector.
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) {
  auto* custom = nb::inst_ptr<PyCustomFunction>(self);

  nb::handle fun_handle = nb::find(custom->fun_);
  Py_VISIT(fun_handle.ptr());

  if (custom->vjp_fun_) {
    nb::handle vjp_handle = nb::find(*custom->vjp_fun_);
    Py_VISIT(vjp_handle.ptr());
  }
  if (custom->jvp_fun_) {
    nb::handle jvp_handle = nb::find(*custom->jvp_fun_);
    Py_VISIT(jvp_handle.ptr());
  }
  if (custom->vmap_fun_) {
    nb::handle vmap_handle = nb::find(*custom->vmap_fun_);
    Py_VISIT(vmap_handle.ptr());
  }

  Py_VISIT(Py_TYPE(self));
  return 0;
}
// tp_clear hook for PyCustomFunction: resets the callables held by the
// instance via PyCustomFunction::reset().
int py_custom_function_tp_clear(PyObject* self) {
  nb::inst_ptr<PyCustomFunction>(self)->reset();
  return 0;
}
// Extra type slots for the custom_function binding: the GC traverse and clear
// hooks, followed by the zero sentinel entry.
PyType_Slot py_custom_function_slots[] = {
    {Py_tp_traverse, reinterpret_cast<void*>(py_custom_function_tp_traverse)},
    {Py_tp_clear, reinterpret_cast<void*>(py_custom_function_tp_clear)},
    {0, nullptr}};
void init_transforms(nb::module_& m) {
nb::class_<PyCustomFunction>(
m,
"custom_function",
nb::type_slots(py_custom_function_slots),
R"pbdoc(
Set up a function for custom gradient and vmap definitions.
@ -1224,8 +1263,10 @@ void init_transforms(nb::module_& m) {
const StrOrSet& argnames) {
auto [argnums_vec, argnames_set] =
validate_argnums_argnames(argnums, argnames);
return nb::cpp_function(py_value_and_grad(
fun, argnums_vec, argnames_set, "[value_and_grad]", false));
return mlx_func(
py_value_and_grad(
fun, argnums_vec, argnames_set, "[value_and_grad]", false),
fun);
},
"fun"_a,
"argnums"_a = nb::none(),
@ -1290,9 +1331,11 @@ void init_transforms(nb::module_& m) {
validate_argnums_argnames(argnums, argnames);
auto fn =
py_value_and_grad(fun, argnums_vec, argnames_set, "[grad]", true);
return nb::cpp_function([fn](nb::args& args, nb::kwargs& kwargs) {
return mlx_func(
[fn = std::move(fn)](nb::args& args, nb::kwargs& kwargs) {
return fn(args, kwargs).second;
});
},
fun);
},
"fun"_a,
"argnums"_a = nb::none(),
@ -1324,7 +1367,8 @@ void init_transforms(nb::module_& m) {
[](const nb::callable& fun,
const nb::object& in_axes,
const nb::object& out_axes) {
return nb::cpp_function(py_vmap(fun, in_axes, out_axes));
return mlx_func(
py_vmap(fun, in_axes, out_axes), fun, in_axes, out_axes);
},
"fun"_a,
"in_axes"_a = 0,
@ -1379,11 +1423,15 @@ void init_transforms(nb::module_& m) {
d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);
auto sig_str = sig.str();
return nb::cpp_function(
return mlx_func(
nb::cpp_function(
PyCompiledFun{fun, inputs, outputs, shapeless},
nb::name(name.c_str()),
nb::sig(sig_str.c_str()),
doc.c_str());
doc.c_str()),
fun,
inputs,
outputs);
},
"fun"_a,
"inputs"_a = nb::none(),
@ -1435,7 +1483,7 @@ void init_transforms(nb::module_& m) {
)pbdoc");
m.def(
"checkpoint",
[](nb::callable fun) { return nb::cpp_function(PyCheckpointedFun{fun}); },
[](nb::callable fun) { return mlx_func(PyCheckpointedFun{fun}, fun); },
"fun"_a);
// Register static Python object cleanup before the interpreter exits

View File

@ -1,5 +1,6 @@
# Copyright © 2023 Apple Inc.
import gc
import unittest
import mlx.core as mx
@ -737,6 +738,38 @@ class TestAutograd(mlx_tests.MLXTestCase):
expected[4:-5:-2] = tan_b
self.assertTrue(mx.allclose(grad, expected))
def test_leaks(self):
    # For each transform, build a dict -> transformed function -> closure ->
    # dict reference cycle several times, collect, and check that the active
    # memory is back at its starting value. Without Metal both readings are 0.
    transforms = [
        mx.grad,
        mx.value_and_grad,
        mx.custom_function,
        mx.checkpoint,
    ]
    for transform in transforms:
        mem_pre = mx.metal.get_active_memory() if mx.metal.is_available() else 0

        def outer():
            d = {}

            def f(x):
                return d["x"]

            d["f"] = transform(f)
            d["x"] = mx.array([0] * 1000)

        for _ in range(5):
            outer()

        gc.collect()

        mem_post = mx.metal.get_active_memory() if mx.metal.is_available() else 0
        self.assertEqual(mem_pre, mem_post)
if __name__ == "__main__":
unittest.main()

View File

@ -1,5 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
import gc
import io
import unittest
from functools import partial
@ -926,6 +927,32 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertEqual(out[0].shape, (3, 1, 4, 2))
self.assertEqual(out[1].shape, (2, 2, 5))
def test_leaks(self):
    # Build a dict -> compiled function -> closure -> dict reference cycle
    # several times, collect, and check that the active memory is back at its
    # starting value. Without Metal both readings are 0.
    mem_pre = mx.metal.get_active_memory() if mx.metal.is_available() else 0

    def outer():
        d = {}

        def f(x):
            return d["x"]

        d["f"] = mx.compile(f)
        d["x"] = mx.array([0] * 1000)

    for _ in range(5):
        outer()

    gc.collect()

    mem_post = mx.metal.get_active_memory() if mx.metal.is_available() else 0
    self.assertEqual(mem_pre, mem_post)
if __name__ == "__main__":
unittest.main()

View File

@ -1,5 +1,6 @@
# Copyright © 2024 Apple Inc.
import gc
import os
import tempfile
import unittest
@ -239,6 +240,33 @@ class TestExportImport(mlx_tests.MLXTestCase):
constants_size = constant.nbytes + 8192
self.assertTrue(os.path.getsize(path) < constants_size)
def test_leaks(self):
    # Build a dict -> exporter -> closure -> dict reference cycle several
    # times, collect, and check that the active memory is back at its starting
    # value. Without Metal both readings are 0.
    path = os.path.join(self.test_dir, "fn.mlxfn")
    mem_pre = mx.metal.get_active_memory() if mx.metal.is_available() else 0

    def outer():
        d = {}

        def f(x):
            return d["x"]

        d["f"] = mx.exporter(path, f)
        d["x"] = mx.array([0] * 1000)

    for _ in range(5):
        outer()

    gc.collect()

    mem_post = mx.metal.get_active_memory() if mx.metal.is_available() else 0
    self.assertEqual(mem_pre, mem_post)
if __name__ == "__main__":
unittest.main()

View File

@ -1,5 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
import gc
import unittest
import mlx.core as mx
@ -608,6 +609,32 @@ class TestVmap(mlx_tests.MLXTestCase):
self.assertEqual(fx.shape, (5, 6, 7))
self.assertEqual(fy.shape, (4, 5, 6, 7))
def test_leaks(self):
    # Build a dict -> vmapped function -> closure -> dict reference cycle
    # several times, collect, and check that the active memory is back at its
    # starting value. Without Metal both readings are 0.
    mem_pre = mx.metal.get_active_memory() if mx.metal.is_available() else 0

    def outer():
        d = {}

        def f(x):
            return d["x"]

        d["f"] = mx.vmap(f)
        d["x"] = mx.array([0] * 1000)

    for _ in range(5):
        outer()

    gc.collect()

    mem_post = mx.metal.get_active_memory() if mx.metal.is_available() else 0
    self.assertEqual(mem_pre, mem_post)
if __name__ == "__main__":
unittest.main()