mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 10:41:14 +08:00
Fixes segfault when compiling checkpointed functions (#1235)
This commit is contained in:
parent
2615660e62
commit
b05bcfd27f
@ -266,6 +266,10 @@ class CompilerCache {
|
|||||||
cache_.erase(fun_id);
|
cache_.erase(fun_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void clear() {
|
||||||
|
cache_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
CompilerCache() {
|
CompilerCache() {
|
||||||
// Make sure the allocator is fully
|
// Make sure the allocator is fully
|
||||||
@ -859,6 +863,10 @@ void compile_erase(std::uintptr_t fun_id) {
|
|||||||
detail::compiler_cache().erase(fun_id);
|
detail::compiler_cache().erase(fun_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void compile_clear_cache() {
|
||||||
|
detail::compiler_cache().clear();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
|
@ -27,6 +27,10 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
// Erase cached compile functions
|
// Erase cached compile functions
|
||||||
void compile_erase(std::uintptr_t fun_id);
|
void compile_erase(std::uintptr_t fun_id);
|
||||||
|
|
||||||
|
// Clear the compiler cache causing a recompilation of all compiled functions
|
||||||
|
// when called again.
|
||||||
|
void compile_clear_cache();
|
||||||
|
|
||||||
// Create an InTracing object during tracing operations to signify to the rest
|
// Create an InTracing object during tracing operations to signify to the rest
|
||||||
// of the codebase that we are during tracing so evals should not throw away
|
// of the codebase that we are during tracing so evals should not throw away
|
||||||
// the graph.
|
// the graph.
|
||||||
|
@ -536,7 +536,6 @@ struct PyCompiledFun {
|
|||||||
class PyCheckpointedFun {
|
class PyCheckpointedFun {
|
||||||
public:
|
public:
|
||||||
PyCheckpointedFun(nb::callable fun) : fun_(std::move(fun)) {}
|
PyCheckpointedFun(nb::callable fun) : fun_(std::move(fun)) {}
|
||||||
|
|
||||||
~PyCheckpointedFun() {
|
~PyCheckpointedFun() {
|
||||||
nb::gil_scoped_acquire gil;
|
nb::gil_scoped_acquire gil;
|
||||||
|
|
||||||
@ -968,5 +967,8 @@ void init_transforms(nb::module_& m) {
|
|||||||
|
|
||||||
// Register static Python object cleanup before the interpreter exits
|
// Register static Python object cleanup before the interpreter exits
|
||||||
auto atexit = nb::module_::import_("atexit");
|
auto atexit = nb::module_::import_("atexit");
|
||||||
atexit.attr("register")(nb::cpp_function([]() { tree_cache().clear(); }));
|
atexit.attr("register")(nb::cpp_function([]() {
|
||||||
|
tree_cache().clear();
|
||||||
|
detail::compile_clear_cache();
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user