From b05bcfd27f5f1293401b74dce02e38c8fd7ef66a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 26 Jun 2024 16:14:45 -0700 Subject: [PATCH] Fixes segfault when compiling checkpointed functions (#1235) --- mlx/compile.cpp | 8 ++++++++ mlx/transforms_impl.h | 4 ++++ python/src/transforms.cpp | 6 ++++-- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index cf92d7728..06d85be16 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -266,6 +266,10 @@ class CompilerCache { cache_.erase(fun_id); } + void clear() { + cache_.clear(); + } + private: CompilerCache() { // Make sure the allocator is fully @@ -859,6 +863,10 @@ void compile_erase(std::uintptr_t fun_id) { detail::compiler_cache().erase(fun_id); } +void compile_clear_cache() { + detail::compiler_cache().clear(); +} + } // namespace detail std::function(const std::vector&)> compile( diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index f81dba24f..76dc0ad84 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -27,6 +27,10 @@ std::function(const std::vector&)> compile( // Erase cached compile functions void compile_erase(std::uintptr_t fun_id); +// Clear the compiler cache causing a recompilation of all compiled functions +// when called again. +void compile_clear_cache(); + // Create an InTracing object during tracing operations to signify to the rest // of the codebase that we are during tracing so evals should not throw away // the graph. diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 201591e31..d14c41369 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -536,7 +536,6 @@ struct PyCompiledFun { class PyCheckpointedFun { public: PyCheckpointedFun(nb::callable fun) : fun_(std::move(fun)) {} - ~PyCheckpointedFun() { nb::gil_scoped_acquire gil; @@ -968,5 +967,8 @@ void init_transforms(nb::module_& m) { // Register static Python object cleanup before the interpreter exits auto atexit = nb::module_::import_("atexit"); - atexit.attr("register")(nb::cpp_function([]() { tree_cache().clear(); })); + atexit.attr("register")(nb::cpp_function([]() { + tree_cache().clear(); + detail::compile_clear_cache(); + })); }