diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 9a55bf902..2801e4a67 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -99,6 +99,30 @@ const std::filesystem::path& ptx_cache_dir() { return cache; } +std::filesystem::path get_ptx_path( + const std::filesystem::path& cache_dir, + const std::string& module_name) { +#ifdef _WIN32 + constexpr int max_file_name_length = 140; +#else + constexpr int max_file_name_length = 245; +#endif + + if (module_name.size() <= max_file_name_length) { + return cache_dir / (module_name + ".ptx"); + } + + auto ptx_path = cache_dir; + int offset = 0; + while (module_name.size() - offset > max_file_name_length) { + ptx_path /= module_name.substr(offset, max_file_name_length); + offset += max_file_name_length; + } + ptx_path /= module_name.substr(offset) + ".ptx"; + + return ptx_path; +} + // Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|. bool read_cached_ptx( const std::filesystem::path& cache_dir, @@ -109,7 +133,7 @@ bool read_cached_ptx( return false; } - auto ptx_path = cache_dir / (module_name + ".ptx"); + auto ptx_path = get_ptx_path(cache_dir, module_name); std::error_code error; auto ptx_size = std::filesystem::file_size(ptx_path, error); if (error) { @@ -122,7 +146,7 @@ bool read_cached_ptx( ptx.resize(ptx_size); ptx_file.read(ptx.data(), ptx_size); - std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary); + std::ifstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary); std::string line; while (std::getline(txt_file, line)) { auto tab = line.find('\t'); @@ -144,16 +168,26 @@ void write_cached_ptx( return; } - std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary); + auto ptx_path = get_ptx_path(cache_dir, module_name); + + // Ensure that the directory exists + auto parent = ptx_path.parent_path(); + if (parent != cache_dir) { + std::filesystem::create_directories(parent); + } + + // Write the compiled code and mangled names + std::ofstream ptx_file(ptx_path, std::ios::binary); if (!ptx.empty()) { ptx_file.write(&ptx.front(), ptx.size()); } - std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary); + std::ofstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary); for (const auto& [name, mangled] : ptx_kernels) { txt_file << name << "\t" << mangled << std::endl; } - std::ofstream source_file(cache_dir / (module_name + ".cu")); + // Write the generated code + std::ofstream source_file(ptx_path.replace_extension(".cu")); source_file << source_code; }