[CUDA] Fix segfault on exit (#2424)

* fix cuda segfault on exit

* comment
This commit is contained in:
Awni Hannun
2025-07-27 08:08:13 -07:00
committed by GitHub
parent 4ad53414dd
commit b9e88fb976
3 changed files with 13 additions and 2 deletions

View File

@@ -9,7 +9,6 @@
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <unordered_map>
#include <fmt/format.h>
#include <nvrtc.h>
@@ -330,11 +329,16 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) {
return it->second;
}
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
static std::unordered_map<std::string, JitModule> map;
return map;
}
JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,
const KernelBuilder& builder) {
static std::unordered_map<std::string, JitModule> map;
auto& map = get_jit_module_cache();
auto it = map.find(name);
if (it == map.end()) {
it = map.try_emplace(name, cu::device(device), name, builder).first;