From 55935ccae7729e002027cc4d0bab4da55c7d6302 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 18 Apr 2025 12:46:53 -0700 Subject: [PATCH] fix py gc edge case (#2079) --- python/src/export.cpp | 5 ++++- python/src/mlx_func.cpp | 2 +- python/src/transforms.cpp | 6 +++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/python/src/export.cpp b/python/src/export.cpp index feefeb12c..0f3bbc1b6 100644 --- a/python/src/export.cpp +++ b/python/src/export.cpp @@ -98,9 +98,12 @@ int py_function_exporter_tp_traverse( PyObject* self, visitproc visit, void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } auto* p = nb::inst_ptr(self); Py_VISIT(p->dep_.ptr()); - Py_VISIT(Py_TYPE(self)); return 0; } diff --git a/python/src/mlx_func.cpp b/python/src/mlx_func.cpp index b2eca5f6f..2f0589bb6 100644 --- a/python/src/mlx_func.cpp +++ b/python/src/mlx_func.cpp @@ -16,12 +16,12 @@ struct gc_func { }; int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) { + Py_VISIT(Py_TYPE(self)); gc_func* w = (gc_func*)self; Py_VISIT(w->func); for (auto d : w->deps) { Py_VISIT(d); } - Py_VISIT(Py_TYPE(self)); return 0; }; diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 4a5e2e6ac..c47942b72 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -960,6 +960,11 @@ class PyCustomFunction { }; int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + auto* p = nb::inst_ptr(self); nb::handle v = nb::find(p->fun_); Py_VISIT(v.ptr()); @@ -975,7 +980,6 @@ int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) { nb::handle v = nb::find(*(p->vmap_fun_)); Py_VISIT(v.ptr()); } - Py_VISIT(Py_TYPE(self)); return 0; } int py_custom_function_tp_clear(PyObject* self) {