Configure CUDA JIT kernels once; take precompiled kernel source as bytes.

Summary of what this patch does:

* JitModule::get_kernel gains an overload that takes a configure callback.
  The callback is invoked the first time a kernel is fetched; a bool stored
  next to each CUfunction in kernels_ records that it already ran. The old
  single-argument form forwards to the new one with a no-op callback.
* CustomKernel::eval_gpu moves its cuFuncSetAttribute call for
  CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES into that callback.
* cuModuleLoadDataEx is handed ptx.data() instead of ptx.c_str().
* The Python precompiled_custom_kernel binding takes nb::bytes and builds the
  std::string from its data()/size() pair.

NOTE(review): the callback runs only once per kernel, so a later launch of the
same cached kernel with a larger shared_memory_ would not raise the attribute
again. Confirm that shared memory cannot vary for a given kernel name.
NOTE(review): the CUresult of cuFuncSetAttribute is still ignored, unlike the
driver calls in jit_module.cpp that go through CHECK_CUDA_ERROR.
NOTE(review): "smem > 0 && smem > 48000" is equivalent to "smem > 48000", and
the parameter k of the no-op lambda in jit_module.h is unused.
NOTE(review): jit_module.h now uses std::function and std::pair; confirm that
<functional> and <utility> are included there.
NOTE(review): the std::vector element types in the python/src/fast.cpp context
lines (ScalarOrArray, mx::Shape, mx::Dtype) and the JitModule mapped type of
get_jit_module_cache are presumed; verify against the target tree.

diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp
index 560429b4a..d85e02e51 100644
--- a/mlx/backend/cuda/custom_kernel.cpp
+++ b/mlx/backend/cuda/custom_kernel.cpp
@@ -370,13 +370,13 @@ void CustomKernel::eval_gpu(
   for (const auto& t : copies) {
     encoder.add_temporary(t);
   }
-  auto kernel = mod.get_kernel(kernel_name);
-  if (shared_memory_ > 0 && shared_memory_ > 48000) {
-    cuFuncSetAttribute(
-        kernel,
-        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
-        shared_memory_);
-  }
+  auto kernel =
+      mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
+        if (smem > 0 && smem > 48000) {
+          cuFuncSetAttribute(
+              kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
+        }
+      });
   encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
 }
 
diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp
index 9805b7cf6..367c94392 100644
--- a/mlx/backend/cuda/jit_module.cpp
+++ b/mlx/backend/cuda/jit_module.cpp
@@ -312,7 +312,7 @@ JitModule::JitModule(
   for (const auto& [name, mangled] : ptx_kernels) {
     CUfunction kernel;
     CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
-    kernels_[name] = kernel;
+    kernels_[name] = std::make_pair(kernel, false);
   }
 }
 
@@ -327,7 +327,7 @@ JitModule::JitModule(
       CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
   void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)};
   CUresult jit_result = cuModuleLoadDataEx(
-      &module_, ptx.c_str(), std::size(options), options, values);
+      &module_, ptx.data(), std::size(options), options, values);
   if (jit_result != CUDA_SUCCESS) {
     throw std::runtime_error(fmt::format(
         "Failed to load compiled {} kernel: {}.", module_name, jit_log));
@@ -337,7 +337,7 @@ JitModule::JitModule(
   for (const auto& name : kernel_names) {
     CUfunction kernel;
     CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, name.c_str()));
-    kernels_[name] = kernel;
+    kernels_[name] = std::make_pair(kernel, false);
   }
 }
 
@@ -345,13 +345,23 @@ JitModule::~JitModule() {
   CHECK_CUDA_ERROR(cuModuleUnload(module_));
 }
 
-CUfunction JitModule::get_kernel(const std::string& kernel_name) {
+CUfunction JitModule::get_kernel(
+    const std::string& kernel_name,
+    std::function<void(CUfunction)> configure_kernel) {
   auto it = kernels_.find(kernel_name);
   if (it == kernels_.end()) {
     throw std::runtime_error(
         fmt::format("There is no kernel named {}.", kernel_name));
   }
-  return it->second;
+
+  // If it is the first time we run this kernel then configure it. Do it only
+  // once!
+  if (!it->second.second) {
+    configure_kernel(it->second.first);
+    it->second.second = true;
+  }
+
+  return it->second.first;
 }
 
 std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h
index 1de77bf26..3aa7ffa8a 100644
--- a/mlx/backend/cuda/jit_module.h
+++ b/mlx/backend/cuda/jit_module.h
@@ -94,11 +94,16 @@ class JitModule {
   JitModule(const JitModule&) = delete;
   JitModule& operator=(const JitModule&) = delete;
 
-  CUfunction get_kernel(const std::string& kernel_name);
+  CUfunction get_kernel(
+      const std::string& kernel_name,
+      std::function<void(CUfunction)> configure_kernel);
+  CUfunction get_kernel(const std::string& kernel_name) {
+    return get_kernel(kernel_name, [](auto k) {});
+  }
 
  private:
   CUmodule module_{nullptr};
-  std::unordered_map<std::string, CUfunction> kernels_;
+  std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
 };
 
 std::unordered_map<std::string, JitModule>& get_jit_module_cache();
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index b359f84bd..7105d12cb 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -388,7 +388,7 @@ void init_fast(nb::module_& parent_module) {
   m.def(
       "precompiled_custom_kernel",
       [](const std::string& name,
-         const std::string& compiled_source,
+         const nb::bytes compiled_source,
          const std::vector<ScalarOrArray>& inputs_,
          const std::vector<mx::Shape>& output_shapes,
          const std::vector<mx::Dtype>& output_dtypes,
@@ -429,7 +429,9 @@ void init_fast(nb::module_& parent_module) {
 
         return mx::fast::precompiled_custom_kernel(
             name,
-            compiled_source,
+            std::string(
+                static_cast<const char*>(compiled_source.data()),
+                compiled_source.size()),
             inputs,
             output_shapes,
             output_dtypes,