fix for max block dim (#2631)

This commit is contained in:
Awni Hannun
2025-09-29 08:59:25 -07:00
committed by GitHub
parent e76a8dd5c5
commit dc371ae7a5
7 changed files with 67 additions and 21 deletions

View File

@@ -297,7 +297,8 @@ void load_module(
const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
CUmodule& module_,
std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>>&
kernels) {
// Load module.
char jit_log[4089] = {};
CUjit_option options[] = {
@@ -314,7 +315,7 @@ void load_module(
for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
kernels[name] = std::make_pair(kernel, false);
kernels[name] = std::make_tuple(kernel, false, 0);
}
}
@@ -358,7 +359,7 @@ JitModule::~JitModule() {
CHECK_CUDA_ERROR(cuModuleUnload(module_));
}
CUfunction JitModule::get_kernel(
std::pair<CUfunction, uint> JitModule::get_kernel_and_dims(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
auto it = kernels_.find(kernel_name);
@@ -369,14 +370,22 @@ CUfunction JitModule::get_kernel(
// If it is the first time we run this kernel then configure it. Do it only
// once!
if (!it->second.second) {
auto kernel = std::get<0>(it->second);
if (!std::get<1>(it->second)) {
if (configure_kernel) {
configure_kernel(it->second.first);
configure_kernel(kernel);
}
it->second.second = true;
std::get<1>(it->second) = true;
std::get<2>(it->second) = max_occupancy_block_dim(kernel);
}
return it->second.first;
return {kernel, std::get<2>(it->second)};
}
// Convenience wrapper around get_kernel_and_dims(): forwards `kernel_name`
// and `configure_kernel` unchanged and returns only the kernel handle (the
// first element of the returned pair), discarding the second element.
CUfunction JitModule::get_kernel(
    const std::string& kernel_name,
    std::function<void(CUfunction)> configure_kernel) {
  auto kernel_and_dims =
      get_kernel_and_dims(kernel_name, std::move(configure_kernel));
  return kernel_and_dims.first;
}
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {