This commit is contained in:
Angelos Katharopoulos 2025-08-15 14:08:03 -07:00
parent 3b94e37270
commit 055c1ca929
4 changed files with 33 additions and 16 deletions

View File

@ -370,13 +370,13 @@ void CustomKernel::eval_gpu(
for (const auto& t : copies) {
encoder.add_temporary(t);
}
auto kernel = mod.get_kernel(kernel_name);
if (shared_memory_ > 0 && shared_memory_ > 48000) {
cuFuncSetAttribute(
kernel,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_memory_);
}
auto kernel =
mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
if (smem > 0 && smem > 48000) {
cuFuncSetAttribute(
kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
}
});
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
}

View File

@ -312,7 +312,7 @@ JitModule::JitModule(
for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
kernels_[name] = kernel;
kernels_[name] = std::make_pair(kernel, false);
}
}
@ -327,7 +327,7 @@ JitModule::JitModule(
CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)};
CUresult jit_result = cuModuleLoadDataEx(
&module_, ptx.c_str(), std::size(options), options, values);
&module_, ptx.data(), std::size(options), options, values);
if (jit_result != CUDA_SUCCESS) {
throw std::runtime_error(fmt::format(
"Failed to load compiled {} kernel: {}.", module_name, jit_log));
@ -337,7 +337,7 @@ JitModule::JitModule(
for (const auto& name : kernel_names) {
CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, name.c_str()));
kernels_[name] = kernel;
kernels_[name] = std::make_pair(kernel, false);
}
}
@ -345,13 +345,23 @@ JitModule::~JitModule() {
CHECK_CUDA_ERROR(cuModuleUnload(module_));
}
CUfunction JitModule::get_kernel(const std::string& kernel_name) {
CUfunction JitModule::get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
auto it = kernels_.find(kernel_name);
if (it == kernels_.end()) {
throw std::runtime_error(
fmt::format("There is no kernel named {}.", kernel_name));
}
return it->second;
// If it is the first time we run this kernel then configure it. Do it only
// once!
if (!it->second.second) {
configure_kernel(it->second.first);
it->second.second = true;
}
return it->second.first;
}
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {

View File

@ -94,11 +94,16 @@ class JitModule {
JitModule(const JitModule&) = delete;
JitModule& operator=(const JitModule&) = delete;
CUfunction get_kernel(const std::string& kernel_name);
CUfunction get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel);
CUfunction get_kernel(const std::string& kernel_name) {
return get_kernel(kernel_name, [](auto k) {});
}
private:
CUmodule module_{nullptr};
std::unordered_map<std::string, CUfunction> kernels_;
std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
};
std::unordered_map<std::string, JitModule>& get_jit_module_cache();

View File

@ -388,7 +388,7 @@ void init_fast(nb::module_& parent_module) {
m.def(
"precompiled_custom_kernel",
[](const std::string& name,
const std::string& compiled_source,
const nb::bytes compiled_source,
const std::vector<ScalarOrArray>& inputs_,
const std::vector<mx::Shape>& output_shapes,
const std::vector<mx::Dtype>& output_dtypes,
@ -429,7 +429,9 @@ void init_fast(nb::module_& parent_module) {
return mx::fast::precompiled_custom_kernel(
name,
compiled_source,
std::string(
static_cast<const char*>(compiled_source.data()),
compiled_source.size()),
inputs,
output_shapes,
output_dtypes,