[CUDA] Fix segfault on exit (#2424)

* fix cuda segfault on exit

* comment
This commit is contained in:
Awni Hannun 2025-07-27 08:08:13 -07:00 committed by GitHub
parent 4ad53414dd
commit b9e88fb976
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 2 deletions

View File

@ -1,6 +1,7 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/worker.h" #include "mlx/backend/cuda/worker.h"
#include "mlx/utils.h" #include "mlx/utils.h"
@ -54,6 +55,10 @@ Device::Device(int device) : device_(device) {
CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_)); CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
// The cudnn handle is used by Convolution. // The cudnn handle is used by Convolution.
CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_)); CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
// Initializing the jit module cache here ensures it is not
// unloaded before any evaluation is done
get_jit_module_cache();
} }
Device::~Device() { Device::~Device() {

View File

@ -9,7 +9,6 @@
#include <cstdlib> #include <cstdlib>
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include <unordered_map>
#include <fmt/format.h> #include <fmt/format.h>
#include <nvrtc.h> #include <nvrtc.h>
@ -330,11 +329,16 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) {
return it->second; return it->second;
} }
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
static std::unordered_map<std::string, JitModule> map;
return map;
}
JitModule& get_jit_module( JitModule& get_jit_module(
const mlx::core::Device& device, const mlx::core::Device& device,
const std::string& name, const std::string& name,
const KernelBuilder& builder) { const KernelBuilder& builder) {
static std::unordered_map<std::string, JitModule> map; auto& map = get_jit_module_cache();
auto it = map.find(name); auto it = map.find(name);
if (it == map.end()) { if (it == map.end()) {
it = map.try_emplace(name, cu::device(device), name, builder).first; it = map.try_emplace(name, cu::device(device), name, builder).first;

View File

@ -99,6 +99,8 @@ class JitModule {
std::unordered_map<std::string, CUfunction> kernels_; std::unordered_map<std::string, CUfunction> kernels_;
}; };
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
JitModule& get_jit_module( JitModule& get_jit_module(
const mlx::core::Device& device, const mlx::core::Device& device,
const std::string& name, const std::string& name,