diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 41a583996e..38bc73f7fd 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -1,6 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
 #include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/jit_module.h"
 #include "mlx/backend/cuda/worker.h"
 #include "mlx/utils.h"
 
@@ -54,6 +55,10 @@ Device::Device(int device) : device_(device) {
   CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
   // The cudnn handle is used by Convolution.
   CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
+
+  // Initializing the jit module cache here ensures it is not
+  // unloaded before any evaluation is done
+  get_jit_module_cache();
 }
 
 Device::~Device() {
diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp
index 6585c452a3..25db207e33 100644
--- a/mlx/backend/cuda/jit_module.cpp
+++ b/mlx/backend/cuda/jit_module.cpp
@@ -9,7 +9,6 @@
 #include
 #include
 #include
-#include
 
 #include
 #include
@@ -330,11 +329,16 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) {
   return it->second;
 }
 
+std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
+  static std::unordered_map<std::string, JitModule> map;
+  return map;
+}
+
 JitModule& get_jit_module(
     const mlx::core::Device& device,
     const std::string& name,
     const KernelBuilder& builder) {
-  static std::unordered_map<std::string, JitModule> map;
+  auto& map = get_jit_module_cache();
   auto it = map.find(name);
   if (it == map.end()) {
     it = map.try_emplace(name, cu::device(device), name, builder).first;
diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h
index 57da7c87ee..7fe3fa055f 100644
--- a/mlx/backend/cuda/jit_module.h
+++ b/mlx/backend/cuda/jit_module.h
@@ -99,6 +99,8 @@ class JitModule {
   std::unordered_map<std::string, CUfunction> kernels_;
 };
 
+std::unordered_map<std::string, JitModule>& get_jit_module_cache();
+
 JitModule& get_jit_module(
     const mlx::core::Device& device,
     const std::string& name,