Compare commits

..

1 Commits

Author SHA1 Message Date
Anastasiia Filippova
2c9dd955fa
Merge e6ae350999 into cad5c0241c 2025-06-17 15:12:02 -04:00

View File

@ -37,46 +37,36 @@ void check_cu_error(const char* name, CUresult err) {
} }
// Return the location of the CUDA toolkit. // Return the location of the CUDA toolkit.
const std::string& cuda_home() { const char* cuda_home() {
static std::string home = []() -> std::string { const char* home = std::getenv("CUDA_HOME");
const char* home = std::getenv("CUDA_HOME"); if (home) {
if (home) { return home;
return home; }
} home = std::getenv("CUDA_PATH");
home = std::getenv("CUDA_PATH"); if (home) {
if (home) { return home;
return home; }
}
#if defined(__linux__) #if defined(__linux__)
home = "/usr/local/cuda"; home = "/usr/local/cuda";
if (std::filesystem::exists(home)) { if (std::filesystem::exists(home)) {
return home; return home;
} }
#endif #endif
throw std::runtime_error( throw std::runtime_error(
"Environment variable CUDA_HOME or CUDA_PATH is not set."); "Environment variable CUDA_HOME or CUDA_PATH is not set.");
}();
return home;
} }
// Get the cache directory for storing compiled results. // Get the cache directory for storing compiled results.
const std::filesystem::path& ptx_cache_dir() { bool get_ptx_cache_dir(std::filesystem::path* result) {
static std::filesystem::path cache = []() -> std::filesystem::path { auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx";
std::filesystem::path cache; if (!std::filesystem::is_directory(path)) {
if (auto c = std::getenv("MLX_PTX_CACHE"); c) { std::error_code error;
cache = c; if (!std::filesystem::create_directories(path, error)) {
} else { return false;
cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
} }
if (!std::filesystem::exists(cache)) { }
std::error_code error; *result = path;
if (!std::filesystem::create_directories(cache, error)) { return true;
return std::filesystem::path();
}
}
return cache;
}();
return cache;
} }
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|. // Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
@ -85,10 +75,6 @@ bool read_cached_ptx(
const std::string& module_name, const std::string& module_name,
std::vector<char>* ptx, std::vector<char>* ptx,
std::vector<std::pair<std::string, std::string>>* ptx_kernels) { std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
if (cache_dir.empty()) {
return false;
}
auto ptx_path = cache_dir / (module_name + ".ptx"); auto ptx_path = cache_dir / (module_name + ".ptx");
std::error_code error; std::error_code error;
auto ptx_size = std::filesystem::file_size(ptx_path, error); auto ptx_size = std::filesystem::file_size(ptx_path, error);
@ -119,10 +105,6 @@ void write_cached_ptx(
const std::string& module_name, const std::string& module_name,
const std::vector<char>& ptx, const std::vector<char>& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) { const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
if (cache_dir.empty()) {
return;
}
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary); std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
if (!ptx.empty()) { if (!ptx.empty()) {
ptx_file.write(&ptx.front(), ptx.size()); ptx_file.write(&ptx.front(), ptx.size());
@ -202,9 +184,11 @@ JitModule::JitModule(
const std::string& module_name, const std::string& module_name,
const KernelBuilder& builder) { const KernelBuilder& builder) {
// Check cache. // Check cache.
std::filesystem::path cache_dir;
std::vector<char> ptx; std::vector<char> ptx;
std::vector<std::pair<std::string, std::string>> ptx_kernels; std::vector<std::pair<std::string, std::string>> ptx_kernels;
if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) { if (!get_ptx_cache_dir(&cache_dir) ||
!read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) {
// Create program. // Create program.
auto [source_code, kernel_names] = builder(); auto [source_code, kernel_names] = builder();
nvrtcProgram prog; nvrtcProgram prog;
@ -262,7 +246,7 @@ JitModule::JitModule(
} else { } else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data())); CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
} }
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels); write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels);
} }
// Load module. // Load module.