diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 5bc56b25e..e6dbd35da 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -2,6 +2,7 @@ #include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/cuda/device.h" +#include "mlx/version.h" #include "cuda_jit_sources.h" @@ -53,10 +54,11 @@ const std::string& cuda_home() { const std::filesystem::path& ptx_cache_dir() { static std::filesystem::path cache = []() -> std::filesystem::path { std::filesystem::path cache; - if (auto c = std::getenv("MLX_PTX_CACHE"); c) { + if (auto c = std::getenv("MLX_PTX_CACHE_DIR"); c) { cache = c; } else { - cache = std::filesystem::temp_directory_path() / "mlx" / "ptx"; + cache = + std::filesystem::temp_directory_path() / "mlx" / version() / "ptx"; } if (!std::filesystem::exists(cache)) { std::error_code error;