mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Cycle leak break (#1856)
* detect and break leaks in custom function * detect and break leaks in custom function
This commit is contained in:
@@ -60,8 +60,56 @@ validate_and_extract_inputs(
|
||||
return {args_, kwargs_};
|
||||
}
|
||||
|
||||
auto wrap_export_function(const nb::callable& fun) {
|
||||
return [fun](
|
||||
int py_function_exporter_tp_traverse(
|
||||
PyObject* self,
|
||||
visitproc visit,
|
||||
void* arg);
|
||||
|
||||
class PyFunctionExporter {
|
||||
public:
|
||||
PyFunctionExporter(mx::FunctionExporter exporter, nb::handle dep)
|
||||
: exporter_(std::move(exporter)), dep_(dep) {}
|
||||
~PyFunctionExporter() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
}
|
||||
PyFunctionExporter(const PyFunctionExporter&) = delete;
|
||||
PyFunctionExporter& operator=(const PyFunctionExporter&) = delete;
|
||||
PyFunctionExporter& operator=(const PyFunctionExporter&&) = delete;
|
||||
PyFunctionExporter(PyFunctionExporter&& other)
|
||||
: exporter_(std::move(other.exporter_)), dep_(std::move(other.dep_)) {}
|
||||
|
||||
void close() {
|
||||
exporter_.close();
|
||||
}
|
||||
void operator()(
|
||||
const std::vector<mx::array>& args,
|
||||
const std::map<std::string, mx::array>& kwargs) {
|
||||
exporter_(args, kwargs);
|
||||
}
|
||||
|
||||
friend int py_function_exporter_tp_traverse(PyObject*, visitproc, void*);
|
||||
|
||||
private:
|
||||
mx::FunctionExporter exporter_;
|
||||
nb::handle dep_;
|
||||
};
|
||||
|
||||
int py_function_exporter_tp_traverse(
|
||||
PyObject* self,
|
||||
visitproc visit,
|
||||
void* arg) {
|
||||
auto* p = nb::inst_ptr<PyFunctionExporter>(self);
|
||||
Py_VISIT(p->dep_.ptr());
|
||||
Py_VISIT(Py_TYPE(self));
|
||||
return 0;
|
||||
}
|
||||
|
||||
PyType_Slot py_function_exporter_slots[] = {
|
||||
{Py_tp_traverse, (void*)py_function_exporter_tp_traverse},
|
||||
{0, 0}};
|
||||
|
||||
auto wrap_export_function(nb::callable fun) {
|
||||
return [fun = std::move(fun)](
|
||||
const std::vector<mx::array>& args_,
|
||||
const std::map<std::string, mx::array>& kwargs_) {
|
||||
auto kwargs = nb::dict();
|
||||
@@ -173,21 +221,21 @@ void init_export(nb::module_& m) {
|
||||
>>> out = fn((a, b), {"x": x, "y": y}[0]
|
||||
)pbdoc");
|
||||
|
||||
nb::class_<mx::FunctionExporter>(
|
||||
nb::class_<PyFunctionExporter>(
|
||||
m,
|
||||
"FunctionExporter",
|
||||
nb::type_slots(py_function_exporter_slots),
|
||||
R"pbdoc(
|
||||
A context managing class for exporting multiple traces of the same
|
||||
function to a file.
|
||||
|
||||
Make an instance of this class by calling fun:`mx.exporter`.
|
||||
)pbdoc")
|
||||
.def("close", &mx::FunctionExporter::close)
|
||||
.def(
|
||||
"__enter__", [](mx::FunctionExporter& exporter) { return &exporter; })
|
||||
.def("close", &PyFunctionExporter::close)
|
||||
.def("__enter__", [](PyFunctionExporter& exporter) { return &exporter; })
|
||||
.def(
|
||||
"__exit__",
|
||||
[](mx::FunctionExporter& exporter,
|
||||
[](PyFunctionExporter& exporter,
|
||||
const std::optional<nb::object>&,
|
||||
const std::optional<nb::object>&,
|
||||
const std::optional<nb::object>&) { exporter.close(); },
|
||||
@@ -196,7 +244,7 @@ void init_export(nb::module_& m) {
|
||||
"traceback"_a = nb::none())
|
||||
.def(
|
||||
"__call__",
|
||||
[](mx::FunctionExporter& exporter,
|
||||
[](PyFunctionExporter& exporter,
|
||||
const nb::args& args,
|
||||
const nb::kwargs& kwargs) {
|
||||
auto [args_, kwargs_] =
|
||||
@@ -206,8 +254,9 @@ void init_export(nb::module_& m) {
|
||||
|
||||
m.def(
|
||||
"exporter",
|
||||
[](const std::string& file, const nb::callable& fun, bool shapeless) {
|
||||
return mx::exporter(file, wrap_export_function(fun), shapeless);
|
||||
[](const std::string& file, nb::callable fun, bool shapeless) {
|
||||
return PyFunctionExporter{
|
||||
mx::exporter(file, wrap_export_function(fun), shapeless), fun};
|
||||
},
|
||||
"file"_a,
|
||||
"fun"_a,
|
||||
|
||||
Reference in New Issue
Block a user