mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Fixes segfault when compiling checkpointed functions (#1235)
This commit is contained in:

committed by
GitHub

parent
2615660e62
commit
b05bcfd27f
@@ -536,7 +536,6 @@ struct PyCompiledFun {
|
||||
class PyCheckpointedFun {
|
||||
public:
|
||||
PyCheckpointedFun(nb::callable fun) : fun_(std::move(fun)) {}
|
||||
|
||||
~PyCheckpointedFun() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
@@ -968,5 +967,8 @@ void init_transforms(nb::module_& m) {
|
||||
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
auto atexit = nb::module_::import_("atexit");
|
||||
atexit.attr("register")(nb::cpp_function([]() { tree_cache().clear(); }));
|
||||
atexit.attr("register")(nb::cpp_function([]() {
|
||||
tree_cache().clear();
|
||||
detail::compile_clear_cache();
|
||||
}));
|
||||
}
|
||||
|
Reference in New Issue
Block a user