From 254c0f56c32de9bc394652d1b3f5c6f3fb4dd44c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 26 Jul 2025 07:03:17 -0700 Subject: [PATCH] fix cuda segfault on exit --- mlx/backend/cuda/device.cpp | 4 ++++ mlx/backend/cuda/jit_module.cpp | 8 ++++++-- mlx/backend/cuda/jit_module.h | 2 ++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 41a583996..ba58b9aec 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -1,6 +1,7 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/cuda/worker.h" #include "mlx/utils.h" @@ -54,6 +55,9 @@ Device::Device(int device) : device_(device) { CHECK_CUBLAS_ERROR(cublasLtCreate(<_)); // The cudnn handle is used by Convolution. CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_)); + + // Ensure the jit module cache is initialized + get_jit_module_cache(); } Device::~Device() { diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 6585c452a..25db207e3 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #include #include @@ -330,11 +329,16 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) { return it->second; } +std::unordered_map& get_jit_module_cache() { + static std::unordered_map map; + return map; +} + JitModule& get_jit_module( const mlx::core::Device& device, const std::string& name, const KernelBuilder& builder) { - static std::unordered_map map; + auto& map = get_jit_module_cache(); auto it = map.find(name); if (it == map.end()) { it = map.try_emplace(name, cu::device(device), name, builder).first; diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h index 57da7c87e..7fe3fa055 100644 --- a/mlx/backend/cuda/jit_module.h +++ b/mlx/backend/cuda/jit_module.h @@ -99,6 +99,8 @@ class JitModule { std::unordered_map kernels_; }; +std::unordered_map& get_jit_module_cache(); + JitModule& get_jit_module( const mlx::core::Device& device, const std::string& name,