[CUDA] Add work per thread to compile (#2368)

This commit is contained in:
Angelos Katharopoulos
2025-07-17 06:47:52 -07:00
committed by GitHub
parent b2273733ea
commit 6b1b8ea91b
2 changed files with 99 additions and 24 deletions

View File

@@ -121,7 +121,8 @@ void write_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
const std::vector<char>& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
const std::string& source_code) {
if (cache_dir.empty()) {
return;
}
@@ -134,6 +135,9 @@ void write_cached_ptx(
for (const auto& [name, mangled] : ptx_kernels) {
txt_file << name << "\t" << mangled << std::endl;
}
std::ofstream source_file(cache_dir / (module_name + ".cu"));
source_file << source_code;
}
// Return if |device|'s version is not newer than |major|.|minor| version.
@@ -272,7 +276,8 @@ JitModule::JitModule(
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
}
// Load module.