From b3d7b8537610c2db2b1875deb5b1d230c47e8b7b Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 17 Jun 2025 23:55:56 -0700 Subject: [PATCH] Make ptx cache settable by environment variable (#2304) --- mlx/backend/cuda/jit_module.cpp | 72 ++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 8a033523c..af8f7dc75 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -37,36 +37,46 @@ void check_cu_error(const char* name, CUresult err) { } // Return the location of the CUDA toolkit. -const char* cuda_home() { - const char* home = std::getenv("CUDA_HOME"); - if (home) { - return home; - } - home = std::getenv("CUDA_PATH"); - if (home) { - return home; - } +const std::string& cuda_home() { + static std::string home = []() -> std::string { + const char* home = std::getenv("CUDA_HOME"); + if (home) { + return home; + } + home = std::getenv("CUDA_PATH"); + if (home) { + return home; + } #if defined(__linux__) - home = "/usr/local/cuda"; - if (std::filesystem::exists(home)) { - return home; - } + home = "/usr/local/cuda"; + if (std::filesystem::exists(home)) { + return home; + } #endif - throw std::runtime_error( - "Environment variable CUDA_HOME or CUDA_PATH is not set."); + throw std::runtime_error( + "Environment variable CUDA_HOME or CUDA_PATH is not set."); + }(); + return home; } // Get the cache directory for storing compiled results. -bool get_ptx_cache_dir(std::filesystem::path* result) { - auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx"; - if (!std::filesystem::is_directory(path)) { - std::error_code error; - if (!std::filesystem::create_directories(path, error)) { - return false; +const std::filesystem::path& ptx_cache_dir() { + static std::filesystem::path cache = []() -> std::filesystem::path { + std::filesystem::path cache; + if (auto c = std::getenv("MLX_PTX_CACHE"); c) { + cache = c; + } else { + cache = std::filesystem::temp_directory_path() / "mlx" / "ptx"; } - } - *result = path; - return true; + if (!std::filesystem::exists(cache)) { + std::error_code error; + if (!std::filesystem::create_directories(cache, error)) { + return std::filesystem::path(); + } + } + return cache; + }(); + return cache; } // Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|. @@ -75,6 +85,10 @@ bool read_cached_ptx( const std::string& module_name, std::vector* ptx, std::vector>* ptx_kernels) { + if (cache_dir.empty()) { + return false; + } + auto ptx_path = cache_dir / (module_name + ".ptx"); std::error_code error; auto ptx_size = std::filesystem::file_size(ptx_path, error); @@ -105,6 +119,10 @@ void write_cached_ptx( const std::string& module_name, const std::vector& ptx, const std::vector>& ptx_kernels) { + if (cache_dir.empty()) { + return; + } + std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary); if (!ptx.empty()) { ptx_file.write(&ptx.front(), ptx.size()); @@ -184,11 +202,9 @@ JitModule::JitModule( const std::string& module_name, const KernelBuilder& builder) { // Check cache. - std::filesystem::path cache_dir; std::vector ptx; std::vector> ptx_kernels; - if (!get_ptx_cache_dir(&cache_dir) || - !read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) { + if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) { // Create program. auto [source_code, kernel_names] = builder(); nvrtcProgram prog; @@ -246,7 +262,7 @@ JitModule::JitModule( } else { CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data())); } - write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels); + write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels); } // Load module.