Cycle leak break (#1856)

* detect and break leaks in custom function

* detect and break leaks in custom function
This commit is contained in:
Awni Hannun
2025-02-11 14:45:02 -08:00
committed by GitHub
parent 142b77751d
commit 2a45056ba8
10 changed files with 396 additions and 49 deletions

View File

@@ -60,8 +60,56 @@ validate_and_extract_inputs(
return {args_, kwargs_};
}
auto wrap_export_function(const nb::callable& fun) {
return [fun](
int py_function_exporter_tp_traverse(
PyObject* self,
visitproc visit,
void* arg);
class PyFunctionExporter {
public:
PyFunctionExporter(mx::FunctionExporter exporter, nb::handle dep)
: exporter_(std::move(exporter)), dep_(dep) {}
~PyFunctionExporter() {
nb::gil_scoped_acquire gil;
}
PyFunctionExporter(const PyFunctionExporter&) = delete;
PyFunctionExporter& operator=(const PyFunctionExporter&) = delete;
PyFunctionExporter& operator=(const PyFunctionExporter&&) = delete;
PyFunctionExporter(PyFunctionExporter&& other)
: exporter_(std::move(other.exporter_)), dep_(std::move(other.dep_)) {}
void close() {
exporter_.close();
}
void operator()(
const std::vector<mx::array>& args,
const std::map<std::string, mx::array>& kwargs) {
exporter_(args, kwargs);
}
friend int py_function_exporter_tp_traverse(PyObject*, visitproc, void*);
private:
mx::FunctionExporter exporter_;
nb::handle dep_;
};
int py_function_exporter_tp_traverse(
PyObject* self,
visitproc visit,
void* arg) {
auto* p = nb::inst_ptr<PyFunctionExporter>(self);
Py_VISIT(p->dep_.ptr());
Py_VISIT(Py_TYPE(self));
return 0;
}
PyType_Slot py_function_exporter_slots[] = {
{Py_tp_traverse, (void*)py_function_exporter_tp_traverse},
{0, 0}};
auto wrap_export_function(nb::callable fun) {
return [fun = std::move(fun)](
const std::vector<mx::array>& args_,
const std::map<std::string, mx::array>& kwargs_) {
auto kwargs = nb::dict();
@@ -173,21 +221,21 @@ void init_export(nb::module_& m) {
>>> out = fn((a, b), {"x": x, "y": y}[0]
)pbdoc");
nb::class_<mx::FunctionExporter>(
nb::class_<PyFunctionExporter>(
m,
"FunctionExporter",
nb::type_slots(py_function_exporter_slots),
R"pbdoc(
A context managing class for exporting multiple traces of the same
function to a file.
Make an instance of this class by calling fun:`mx.exporter`.
)pbdoc")
.def("close", &mx::FunctionExporter::close)
.def(
"__enter__", [](mx::FunctionExporter& exporter) { return &exporter; })
.def("close", &PyFunctionExporter::close)
.def("__enter__", [](PyFunctionExporter& exporter) { return &exporter; })
.def(
"__exit__",
[](mx::FunctionExporter& exporter,
[](PyFunctionExporter& exporter,
const std::optional<nb::object>&,
const std::optional<nb::object>&,
const std::optional<nb::object>&) { exporter.close(); },
@@ -196,7 +244,7 @@ void init_export(nb::module_& m) {
"traceback"_a = nb::none())
.def(
"__call__",
[](mx::FunctionExporter& exporter,
[](PyFunctionExporter& exporter,
const nb::args& args,
const nb::kwargs& kwargs) {
auto [args_, kwargs_] =
@@ -206,8 +254,9 @@ void init_export(nb::module_& m) {
m.def(
"exporter",
[](const std::string& file, const nb::callable& fun, bool shapeless) {
return mx::exporter(file, wrap_export_function(fun), shapeless);
[](const std::string& file, nb::callable fun, bool shapeless) {
return PyFunctionExporter{
mx::exporter(file, wrap_export_function(fun), shapeless), fun};
},
"file"_a,
"fun"_a,