mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Add work per thread to compile (#2368)
This commit is contained in:
committed by
GitHub
parent
b2273733ea
commit
6b1b8ea91b
@@ -121,7 +121,8 @@ void write_cached_ptx(
|
||||
const std::filesystem::path& cache_dir,
|
||||
const std::string& module_name,
|
||||
const std::vector<char>& ptx,
|
||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
|
||||
const std::string& source_code) {
|
||||
if (cache_dir.empty()) {
|
||||
return;
|
||||
}
|
||||
@@ -134,6 +135,9 @@ void write_cached_ptx(
|
||||
for (const auto& [name, mangled] : ptx_kernels) {
|
||||
txt_file << name << "\t" << mangled << std::endl;
|
||||
}
|
||||
|
||||
std::ofstream source_file(cache_dir / (module_name + ".cu"));
|
||||
source_file << source_code;
|
||||
}
|
||||
|
||||
// Return if |device|'s version is not newer than |major|.|minor| version.
|
||||
@@ -272,7 +276,8 @@ JitModule::JitModule(
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||
}
|
||||
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
|
||||
write_cached_ptx(
|
||||
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
|
||||
}
|
||||
|
||||
// Load module.
|
||||
|
||||
Reference in New Issue
Block a user