Fixes segfault when compiling checkpointed functions (#1235)

This commit is contained in:
Angelos Katharopoulos 2024-06-26 16:14:45 -07:00 committed by GitHub
parent 2615660e62
commit b05bcfd27f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 2 deletions

View File

@ -266,6 +266,10 @@ class CompilerCache {
cache_.erase(fun_id);
}
void clear() {
cache_.clear();
}
private:
CompilerCache() {
// Make sure the allocator is fully
@ -859,6 +863,10 @@ void compile_erase(std::uintptr_t fun_id) {
detail::compiler_cache().erase(fun_id);
}
void compile_clear_cache() {
detail::compiler_cache().clear();
}
} // namespace detail
std::function<std::vector<array>(const std::vector<array>&)> compile(

View File

@ -27,6 +27,10 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
// Erase cached compile functions
void compile_erase(std::uintptr_t fun_id);
// Clear the compiler cache causing a recompilation of all compiled functions
// when called again.
void compile_clear_cache();
// Create an InTracing object during tracing operations to signify to the rest
// of the codebase that we are during tracing so evals should not throw away
// the graph.

View File

@ -536,7 +536,6 @@ struct PyCompiledFun {
class PyCheckpointedFun {
public:
PyCheckpointedFun(nb::callable fun) : fun_(std::move(fun)) {}
~PyCheckpointedFun() {
nb::gil_scoped_acquire gil;
@ -968,5 +967,8 @@ void init_transforms(nb::module_& m) {
// Register static Python object cleanup before the interpreter exits
auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(nb::cpp_function([]() { tree_cache().clear(); }));
atexit.attr("register")(nb::cpp_function([]() {
tree_cache().clear();
detail::compile_clear_cache();
}));
}