From 2a45056ba867c22a047dbc69c2390446eca3a555 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 11 Feb 2025 14:45:02 -0800 Subject: [PATCH] Cycle leak break (#1856) * detect and break leaks in custom function * detect and break leaks in custom function --- python/src/CMakeLists.txt | 1 + python/src/export.cpp | 69 +++++++++++++--- python/src/mlx.cpp | 2 + python/src/mlx_func.cpp | 108 +++++++++++++++++++++++++ python/src/mlx_func.h | 24 ++++++ python/src/transforms.cpp | 126 ++++++++++++++++++++--------- python/tests/test_autograd.py | 33 ++++++++ python/tests/test_compile.py | 27 +++++++ python/tests/test_export_import.py | 28 +++++++ python/tests/test_vmap.py | 27 +++++++ 10 files changed, 396 insertions(+), 49 deletions(-) create mode 100644 python/src/mlx_func.cpp create mode 100644 python/src/mlx_func.h diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index a82c6bb6c..caaa478a3 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -17,6 +17,7 @@ nanobind_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp diff --git a/python/src/export.cpp b/python/src/export.cpp index b2088587a..feefeb12c 100644 --- a/python/src/export.cpp +++ b/python/src/export.cpp @@ -60,8 +60,56 @@ validate_and_extract_inputs( return {args_, kwargs_}; } -auto wrap_export_function(const nb::callable& fun) { - return [fun]( +int py_function_exporter_tp_traverse( + PyObject* self, + visitproc visit, + void* arg); + +class PyFunctionExporter { + public: + PyFunctionExporter(mx::FunctionExporter exporter, nb::handle dep) + : exporter_(std::move(exporter)), dep_(dep) {} + ~PyFunctionExporter() { + nb::gil_scoped_acquire gil; + } + PyFunctionExporter(const PyFunctionExporter&) = delete; + PyFunctionExporter& operator=(const PyFunctionExporter&) = delete; + PyFunctionExporter& operator=(const PyFunctionExporter&&) = delete; + PyFunctionExporter(PyFunctionExporter&& other) + : exporter_(std::move(other.exporter_)), dep_(std::move(other.dep_)) {} + + void close() { + exporter_.close(); + } + void operator()( + const std::vector& args, + const std::map& kwargs) { + exporter_(args, kwargs); + } + + friend int py_function_exporter_tp_traverse(PyObject*, visitproc, void*); + + private: + mx::FunctionExporter exporter_; + nb::handle dep_; +}; + +int py_function_exporter_tp_traverse( + PyObject* self, + visitproc visit, + void* arg) { + auto* p = nb::inst_ptr(self); + Py_VISIT(p->dep_.ptr()); + Py_VISIT(Py_TYPE(self)); + return 0; +} + +PyType_Slot py_function_exporter_slots[] = { + {Py_tp_traverse, (void*)py_function_exporter_tp_traverse}, + {0, 0}}; + +auto wrap_export_function(nb::callable fun) { + return [fun = std::move(fun)]( const std::vector& args_, const std::map& kwargs_) { auto kwargs = nb::dict(); @@ -173,21 +221,21 @@ void init_export(nb::module_& m) { >>> out = fn((a, b), {"x": x, "y": y}[0] )pbdoc"); - nb::class_( + nb::class_( m, "FunctionExporter", + nb::type_slots(py_function_exporter_slots), R"pbdoc( A context managing class for exporting multiple traces of the same function to a file. Make an instance of this class by calling fun:`mx.exporter`. )pbdoc") - .def("close", &mx::FunctionExporter::close) - .def( - "__enter__", [](mx::FunctionExporter& exporter) { return &exporter; }) + .def("close", &PyFunctionExporter::close) + .def("__enter__", [](PyFunctionExporter& exporter) { return &exporter; }) .def( "__exit__", - [](mx::FunctionExporter& exporter, + [](PyFunctionExporter& exporter, const std::optional&, const std::optional&, const std::optional&) { exporter.close(); }, @@ -196,7 +244,7 @@ void init_export(nb::module_& m) { "traceback"_a = nb::none()) .def( "__call__", - [](mx::FunctionExporter& exporter, + [](PyFunctionExporter& exporter, const nb::args& args, const nb::kwargs& kwargs) { auto [args_, kwargs_] = @@ -206,8 +254,9 @@ void init_export(nb::module_& m) { m.def( "exporter", - [](const std::string& file, const nb::callable& fun, bool shapeless) { - return mx::exporter(file, wrap_export_function(fun), shapeless); + [](const std::string& file, nb::callable fun, bool shapeless) { + return PyFunctionExporter{ + mx::exporter(file, wrap_export_function(fun), shapeless), fun}; }, "file"_a, "fun"_a, diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index 5ef7863fe..ecf9a3a13 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -7,6 +7,7 @@ namespace nb = nanobind; +void init_mlx_func(nb::module_&); void init_array(nb::module_&); void init_device(nb::module_&); void init_stream(nb::module_&); @@ -28,6 +29,7 @@ NB_MODULE(core, m) { nb::module_::import_("mlx._os_warning"); nb::set_leak_warnings(false); + init_mlx_func(m); init_device(m); init_stream(m); init_array(m); diff --git a/python/src/mlx_func.cpp b/python/src/mlx_func.cpp new file mode 100644 index 000000000..b2eca5f6f --- /dev/null +++ b/python/src/mlx_func.cpp @@ -0,0 +1,108 @@ +// Copyright © 2025 Apple Inc. + +#include "python/src/mlx_func.h" + +// A garbage collected function which wraps nb::cpp_function +// See https://github.com/wjakob/nanobind/discussions/919 + +struct gc_func { + PyObject_HEAD + // Vector call implementation that forwards calls to nanobind + PyObject* (*vectorcall)(PyObject*, PyObject* const*, size_t, PyObject*); + // The function itself + PyObject* func; + // A non-owning reference to dependencies owned by 'func' + std::vector deps; +}; + +int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) { + gc_func* w = (gc_func*)self; + Py_VISIT(w->func); + for (auto d : w->deps) { + Py_VISIT(d); + } + Py_VISIT(Py_TYPE(self)); + return 0; +}; + +int gc_func_tp_clear(PyObject* self) { + gc_func* w = (gc_func*)self; + Py_CLEAR(w->func); + return 0; +} + +PyObject* gc_func_get_doc(PyObject* self, void*) { + return PyObject_GetAttrString(((gc_func*)self)->func, "__doc__"); +} + +PyObject* gc_func_get_sig(PyObject* self, void*) { + return PyObject_GetAttrString(((gc_func*)self)->func, "__nb_signature__"); +} + +PyObject* gc_func_vectorcall( + PyObject* self, + PyObject* const* args, + size_t nargs, + PyObject* kwnames) { + return PyObject_Vectorcall(((gc_func*)self)->func, args, nargs, kwnames); +} + +void gc_func_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + Py_XDECREF(((gc_func*)self)->func); + PyObject_GC_Del(self); +} + +static PyMemberDef gc_func_members[] = { + {"__vectorcalloffset__", + T_PYSSIZET, + (Py_ssize_t)offsetof(gc_func, vectorcall), + READONLY, + nullptr}, + {nullptr, 0, 0, 0, nullptr}}; + +static PyGetSetDef gc_func_getset[] = { + {"__doc__", gc_func_get_doc, nullptr, nullptr, nullptr}, + {"__nb_signature__", gc_func_get_sig, nullptr, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}}; + +static PyObject* gc_func_getattro(PyObject* self, PyObject* name_) { + gc_func* w = (gc_func*)self; + auto f = PyCFunction(PyType_GetSlot(Py_TYPE(w->func), Py_tp_getattro)); + return f(w->func, name_); +} + +// Table of custom type slots we want to install +PyType_Slot gc_func_slots[] = { + {Py_tp_traverse, (void*)gc_func_tp_traverse}, + {Py_tp_clear, (void*)gc_func_tp_clear}, + {Py_tp_getset, (void*)gc_func_getset}, + {Py_tp_getattro, (void*)gc_func_getattro}, + {Py_tp_members, (void*)gc_func_members}, + {Py_tp_call, (void*)PyVectorcall_Call}, + {Py_tp_dealloc, (void*)gc_func_dealloc}, + {0, 0}}; + +static PyType_Spec gc_func_spec = { + /* .name = */ "mlx.gc_func", + /* .basicsize = */ (int)sizeof(gc_func), + /* .itemsize = */ 0, + /* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | NB_HAVE_VECTORCALL, + /* .slots = */ gc_func_slots}; + +static PyTypeObject* gc_func_tp = nullptr; + +nb::callable mlx_func(nb::object func, std::vector deps) { + gc_func* r = (gc_func*)PyType_GenericAlloc(gc_func_tp, 0); + r->func = func.inc_ref().ptr(); + r->deps = std::move(deps); + r->vectorcall = gc_func_vectorcall; + return nb::steal((PyObject*)r); +} + +void init_mlx_func(nb::module_& m) { + gc_func_tp = (PyTypeObject*)PyType_FromSpec(&gc_func_spec); + if (!gc_func_tp) { + nb::raise("Could not register MLX function type."); + } +} diff --git a/python/src/mlx_func.h b/python/src/mlx_func.h new file mode 100644 index 000000000..90cfb57be --- /dev/null +++ b/python/src/mlx_func.h @@ -0,0 +1,24 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +#include +#include + +namespace nb = nanobind; +using namespace nb::literals; + +nb::callable mlx_func(nb::object func, std::vector deps); + +template +nb::callable mlx_func(F func, Deps&&... deps) { + return mlx_func( + nb::cpp_function(std::move(func)), std::vector{deps.ptr()...}); +} + +template +nb::callable mlx_func(nb::object func, Deps&&... deps) { + return mlx_func(std::move(func), std::vector{deps.ptr()...}); +} diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index e351b9f68..8585bd378 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -19,6 +19,7 @@ #include "mlx/transforms.h" #include "mlx/transforms_impl.h" #include "mlx/utils.h" +#include "python/src/mlx_func.h" #include "python/src/trees.h" namespace mx = mlx::core; @@ -543,9 +544,9 @@ struct PyCompiledFun { tree_cache().erase(fun_id); mx::detail::compile_erase(fun_id); - fun.release().dec_ref(); - captured_inputs.release().dec_ref(); - captured_outputs.release().dec_ref(); + fun.reset(); + captured_inputs.reset(); + captured_outputs.reset(); } }; @@ -555,7 +556,7 @@ class PyCheckpointedFun { ~PyCheckpointedFun() { nb::gil_scoped_acquire gil; - fun_.release().dec_ref(); + fun_.reset(); } struct InnerFunction { @@ -573,8 +574,8 @@ class PyCheckpointedFun { ~InnerFunction() { nb::gil_scoped_acquire gil; - fun_.release().dec_ref(); - args_structure_.release().dec_ref(); + fun_.reset(); + args_structure_.reset(); } std::vector operator()(const std::vector& inputs) { @@ -609,6 +610,10 @@ class PyCheckpointedFun { nb::callable fun_; }; +int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg); + +int py_custom_function_tp_clear(PyObject* self); + /** * PyCustomFunction is the class that implements the python decorator * `mx.custom_function`. @@ -641,17 +646,7 @@ class PyCustomFunction { PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {} ~PyCustomFunction() { nb::gil_scoped_acquire gil; - - fun_.release().dec_ref(); - if (vjp_fun_.has_value()) { - (*vjp_fun_).release().dec_ref(); - } - if (jvp_fun_.has_value()) { - (*jvp_fun_).release().dec_ref(); - } - if (vmap_fun_.has_value()) { - (*vmap_fun_).release().dec_ref(); - } + reset(); } struct InnerFunction { @@ -669,10 +664,10 @@ class PyCustomFunction { ~InnerFunction() { nb::gil_scoped_acquire gil; - fun_.release().dec_ref(); - input_structure_.release().dec_ref(); + fun_.reset(); + input_structure_.reset(); if (output_structure_.use_count() == 1) { - output_structure_->release().dec_ref(); + output_structure_->reset(); } } @@ -703,10 +698,10 @@ class PyCustomFunction { ~InnerVJPFunction() { nb::gil_scoped_acquire gil; - vjp_fun_.release().dec_ref(); - input_structure_.release().dec_ref(); + vjp_fun_.reset(); + input_structure_.reset(); if (output_structure_.use_count() == 1) { - output_structure_->release().dec_ref(); + output_structure_->reset(); } } @@ -746,8 +741,8 @@ class PyCustomFunction { ~InnerJVPFunction() { nb::gil_scoped_acquire gil; - jvp_fun_.release().dec_ref(); - input_structure_.release().dec_ref(); + jvp_fun_.reset(); + input_structure_.reset(); } std::vector operator()( @@ -801,8 +796,8 @@ class PyCustomFunction { ~InnerVmapFunction() { nb::gil_scoped_acquire gil; - vmap_fun_.release().dec_ref(); - input_structure_.release().dec_ref(); + vmap_fun_.reset(); + input_structure_.reset(); } std::pair, std::vector> operator()( @@ -904,6 +899,20 @@ class PyCustomFunction { vmap_fun_ = vmap_fun; return *this; } + void reset() { + fun_.reset(); + if (vjp_fun_.has_value()) { + (*vjp_fun_).reset(); + } + if (jvp_fun_.has_value()) { + (*jvp_fun_).reset(); + } + if (vmap_fun_.has_value()) { + (*vmap_fun_).reset(); + } + } + + friend int py_custom_function_tp_traverse(PyObject*, visitproc, void*); private: std::optional make_vjp_function( @@ -940,10 +949,40 @@ class PyCustomFunction { std::optional vmap_fun_; }; +int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) { + auto* p = nb::inst_ptr(self); + nb::handle v = nb::find(p->fun_); + Py_VISIT(v.ptr()); + if (p->vjp_fun_.has_value()) { + nb::handle v = nb::find(*(p->vjp_fun_)); + Py_VISIT(v.ptr()); + } + if (p->jvp_fun_.has_value()) { + nb::handle v = nb::find(*(p->jvp_fun_)); + Py_VISIT(v.ptr()); + } + if (p->vmap_fun_.has_value()) { + nb::handle v = nb::find(*(p->vmap_fun_)); + Py_VISIT(v.ptr()); + } + Py_VISIT(Py_TYPE(self)); + return 0; +} +int py_custom_function_tp_clear(PyObject* self) { + auto* p = nb::inst_ptr(self); + p->reset(); + return 0; +} +PyType_Slot py_custom_function_slots[] = { + {Py_tp_traverse, (void*)py_custom_function_tp_traverse}, + {Py_tp_clear, (void*)py_custom_function_tp_clear}, + {0, 0}}; + void init_transforms(nb::module_& m) { nb::class_( m, "custom_function", + nb::type_slots(py_custom_function_slots), R"pbdoc( Set up a function for custom gradient and vmap definitions. @@ -1224,8 +1263,10 @@ void init_transforms(nb::module_& m) { const StrOrSet& argnames) { auto [argnums_vec, argnames_set] = validate_argnums_argnames(argnums, argnames); - return nb::cpp_function(py_value_and_grad( - fun, argnums_vec, argnames_set, "[value_and_grad]", false)); + return mlx_func( + py_value_and_grad( + fun, argnums_vec, argnames_set, "[value_and_grad]", false), + fun); }, "fun"_a, "argnums"_a = nb::none(), @@ -1290,9 +1331,11 @@ void init_transforms(nb::module_& m) { validate_argnums_argnames(argnums, argnames); auto fn = py_value_and_grad(fun, argnums_vec, argnames_set, "[grad]", true); - return nb::cpp_function([fn](nb::args& args, nb::kwargs& kwargs) { - return fn(args, kwargs).second; - }); + return mlx_func( + [fn = std::move(fn)](nb::args& args, nb::kwargs& kwargs) { + return fn(args, kwargs).second; + }, + fun); }, "fun"_a, "argnums"_a = nb::none(), @@ -1324,7 +1367,8 @@ void init_transforms(nb::module_& m) { [](const nb::callable& fun, const nb::object& in_axes, const nb::object& out_axes) { - return nb::cpp_function(py_vmap(fun, in_axes, out_axes)); + return mlx_func( + py_vmap(fun, in_axes, out_axes), fun, in_axes, out_axes); }, "fun"_a, "in_axes"_a = 0, @@ -1379,11 +1423,15 @@ void init_transforms(nb::module_& m) { d.is_none() ? "MLX compiled function." : nb::cast(d); auto sig_str = sig.str(); - return nb::cpp_function( - PyCompiledFun{fun, inputs, outputs, shapeless}, - nb::name(name.c_str()), - nb::sig(sig_str.c_str()), - doc.c_str()); + return mlx_func( + nb::cpp_function( + PyCompiledFun{fun, inputs, outputs, shapeless}, + nb::name(name.c_str()), + nb::sig(sig_str.c_str()), + doc.c_str()), + fun, + inputs, + outputs); }, "fun"_a, "inputs"_a = nb::none(), @@ -1435,7 +1483,7 @@ void init_transforms(nb::module_& m) { )pbdoc"); m.def( "checkpoint", - [](nb::callable fun) { return nb::cpp_function(PyCheckpointedFun{fun}); }, + [](nb::callable fun) { return mlx_func(PyCheckpointedFun{fun}, fun); }, "fun"_a); // Register static Python object cleanup before the interpreter exits diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 3c226365f..3ec020270 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -1,5 +1,6 @@ # Copyright © 2023 Apple Inc. +import gc import unittest import mlx.core as mx @@ -737,6 +738,38 @@ class TestAutograd(mlx_tests.MLXTestCase): expected[4:-5:-2] = tan_b self.assertTrue(mx.allclose(grad, expected)) + def test_leaks(self): + for transform in [ + mx.grad, + mx.value_and_grad, + mx.custom_function, + mx.checkpoint, + ]: + if mx.metal.is_available(): + mem_pre = mx.metal.get_active_memory() + else: + mem_pre = 0 + + def outer(): + d = {} + + def f(x): + return d["x"] + + d["f"] = transform(f) + d["x"] = mx.array([0] * 1000) + + for _ in range(5): + outer() + gc.collect() + + if mx.metal.is_available(): + mem_post = mx.metal.get_active_memory() + else: + mem_post = 0 + + self.assertEqual(mem_pre, mem_post) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 1974b9a23..ba6b316ce 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1,5 +1,6 @@ # Copyright © 2023-2024 Apple Inc. +import gc import io import unittest from functools import partial @@ -926,6 +927,32 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertEqual(out[0].shape, (3, 1, 4, 2)) self.assertEqual(out[1].shape, (2, 2, 5)) + def test_leaks(self): + if mx.metal.is_available(): + mem_pre = mx.metal.get_active_memory() + else: + mem_pre = 0 + + def outer(): + d = {} + + def f(x): + return d["x"] + + d["f"] = mx.compile(f) + d["x"] = mx.array([0] * 1000) + + for _ in range(5): + outer() + gc.collect() + + if mx.metal.is_available(): + mem_post = mx.metal.get_active_memory() + else: + mem_post = 0 + + self.assertEqual(mem_pre, mem_post) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index fca89329d..fd62a58f6 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -1,5 +1,6 @@ # Copyright © 2024 Apple Inc. +import gc import os import tempfile import unittest @@ -239,6 +240,33 @@ class TestExportImport(mlx_tests.MLXTestCase): constants_size = constant.nbytes + 8192 self.assertTrue(os.path.getsize(path) < constants_size) + def test_leaks(self): + path = os.path.join(self.test_dir, "fn.mlxfn") + if mx.metal.is_available(): + mem_pre = mx.metal.get_active_memory() + else: + mem_pre = 0 + + def outer(): + d = {} + + def f(x): + return d["x"] + + d["f"] = mx.exporter(path, f) + d["x"] = mx.array([0] * 1000) + + for _ in range(5): + outer() + gc.collect() + + if mx.metal.is_available(): + mem_post = mx.metal.get_active_memory() + else: + mem_post = 0 + + self.assertEqual(mem_pre, mem_post) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index ceadf36a4..2d38bc457 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -1,5 +1,6 @@ # Copyright © 2023-2024 Apple Inc. +import gc import unittest import mlx.core as mx @@ -608,6 +609,32 @@ class TestVmap(mlx_tests.MLXTestCase): self.assertEqual(fx.shape, (5, 6, 7)) self.assertEqual(fy.shape, (4, 5, 6, 7)) + def test_leaks(self): + if mx.metal.is_available(): + mem_pre = mx.metal.get_active_memory() + else: + mem_pre = 0 + + def outer(): + d = {} + + def f(x): + return d["x"] + + d["f"] = mx.vmap(f) + d["x"] = mx.array([0] * 1000) + + for _ in range(5): + outer() + gc.collect() + + if mx.metal.is_available(): + mem_post = mx.metal.get_active_memory() + else: + mem_post = 0 + + self.assertEqual(mem_pre, mem_post) + if __name__ == "__main__": unittest.main()