This commit is contained in:
Angelos Katharopoulos 2025-08-15 14:08:03 -07:00
parent 3b94e37270
commit 055c1ca929
4 changed files with 33 additions and 16 deletions

View File

@ -370,13 +370,13 @@ void CustomKernel::eval_gpu(
for (const auto& t : copies) { for (const auto& t : copies) {
encoder.add_temporary(t); encoder.add_temporary(t);
} }
auto kernel = mod.get_kernel(kernel_name); auto kernel =
if (shared_memory_ > 0 && shared_memory_ > 48000) { mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
if (smem > 0 && smem > 48000) {
cuFuncSetAttribute( cuFuncSetAttribute(
kernel, kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_memory_);
} }
});
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args()); encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
} }

View File

@ -312,7 +312,7 @@ JitModule::JitModule(
for (const auto& [name, mangled] : ptx_kernels) { for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel; CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
kernels_[name] = kernel; kernels_[name] = std::make_pair(kernel, false);
} }
} }
@ -327,7 +327,7 @@ JitModule::JitModule(
CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES}; CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)}; void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)};
CUresult jit_result = cuModuleLoadDataEx( CUresult jit_result = cuModuleLoadDataEx(
&module_, ptx.c_str(), std::size(options), options, values); &module_, ptx.data(), std::size(options), options, values);
if (jit_result != CUDA_SUCCESS) { if (jit_result != CUDA_SUCCESS) {
throw std::runtime_error(fmt::format( throw std::runtime_error(fmt::format(
"Failed to load compiled {} kernel: {}.", module_name, jit_log)); "Failed to load compiled {} kernel: {}.", module_name, jit_log));
@ -337,7 +337,7 @@ JitModule::JitModule(
for (const auto& name : kernel_names) { for (const auto& name : kernel_names) {
CUfunction kernel; CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, name.c_str())); CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, name.c_str()));
kernels_[name] = kernel; kernels_[name] = std::make_pair(kernel, false);
} }
} }
@ -345,13 +345,23 @@ JitModule::~JitModule() {
CHECK_CUDA_ERROR(cuModuleUnload(module_)); CHECK_CUDA_ERROR(cuModuleUnload(module_));
} }
CUfunction JitModule::get_kernel(const std::string& kernel_name) { CUfunction JitModule::get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
auto it = kernels_.find(kernel_name); auto it = kernels_.find(kernel_name);
if (it == kernels_.end()) { if (it == kernels_.end()) {
throw std::runtime_error( throw std::runtime_error(
fmt::format("There is no kernel named {}.", kernel_name)); fmt::format("There is no kernel named {}.", kernel_name));
} }
return it->second;
// If it is the first time we run this kernel then configure it. Do it only
// once!
if (!it->second.second) {
configure_kernel(it->second.first);
it->second.second = true;
}
return it->second.first;
} }
std::unordered_map<std::string, JitModule>& get_jit_module_cache() { std::unordered_map<std::string, JitModule>& get_jit_module_cache() {

View File

@ -94,11 +94,16 @@ class JitModule {
JitModule(const JitModule&) = delete; JitModule(const JitModule&) = delete;
JitModule& operator=(const JitModule&) = delete; JitModule& operator=(const JitModule&) = delete;
CUfunction get_kernel(const std::string& kernel_name); CUfunction get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel);
CUfunction get_kernel(const std::string& kernel_name) {
return get_kernel(kernel_name, [](auto k) {});
}
private: private:
CUmodule module_{nullptr}; CUmodule module_{nullptr};
std::unordered_map<std::string, CUfunction> kernels_; std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
}; };
std::unordered_map<std::string, JitModule>& get_jit_module_cache(); std::unordered_map<std::string, JitModule>& get_jit_module_cache();

View File

@ -388,7 +388,7 @@ void init_fast(nb::module_& parent_module) {
m.def( m.def(
"precompiled_custom_kernel", "precompiled_custom_kernel",
[](const std::string& name, [](const std::string& name,
const std::string& compiled_source, const nb::bytes compiled_source,
const std::vector<ScalarOrArray>& inputs_, const std::vector<ScalarOrArray>& inputs_,
const std::vector<mx::Shape>& output_shapes, const std::vector<mx::Shape>& output_shapes,
const std::vector<mx::Dtype>& output_dtypes, const std::vector<mx::Dtype>& output_dtypes,
@ -429,7 +429,9 @@ void init_fast(nb::module_& parent_module) {
return mx::fast::precompiled_custom_kernel( return mx::fast::precompiled_custom_kernel(
name, name,
compiled_source, std::string(
static_cast<const char*>(compiled_source.data()),
compiled_source.size()),
inputs, inputs,
output_shapes, output_shapes,
output_dtypes, output_dtypes,