From 6b1b8ea91b2bd89f3adbd2b08f67639d0fa92189 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 17 Jul 2025 06:47:52 -0700
Subject: [PATCH] [CUDA] Add work per thread to compile (#2368)

---
 mlx/backend/cuda/compiled.cpp   | 114 ++++++++++++++++++++++++++------
 mlx/backend/cuda/jit_module.cpp |   9 ++-
 2 files changed, 99 insertions(+), 24 deletions(-)

diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp
index 2f3990b90..1aff47b89 100644
--- a/mlx/backend/cuda/compiled.cpp
+++ b/mlx/backend/cuda/compiled.cpp
@@ -53,9 +53,10 @@ struct FusedKernelBuilder {
 
     // Build function signature.
     if (contiguous) {
-      os += "template <typename IdxT = uint32_t>\n";
+      os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
     } else {
-      os += "template <int NDIM, typename IdxT = uint32_t>\n";
+      os +=
+          "template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
     }
     os += fmt::format("__global__ void {}(\n", kernel_name + name);
     for (size_t i = 0; i < params.size(); ++i) {
@@ -67,12 +68,46 @@ struct FusedKernelBuilder {
     }
     os += ") {\n";
 
-    // Index.
+    // Index. For non contiguous kernels we create a separate index
+    // variable per variable otherwise everyone uses `index`.
     os +=
-        "  IdxT index = cg::this_grid().thread_rank();\n"
+        "  IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
         "  if (index >= size) {\n"
         "    return;\n"
         "  }\n";
+    if (!contiguous) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        const auto& x = inputs[i];
+        const std::string& xname = namer.get_name(x);
+        if (is_scalar(x) || is_constant(i)) {
+          continue;
+        }
+        os += "  IdxT " + xname + "_idx = 0;\n";
+      }
+      os += "  {\n";
+      os += "    IdxT loc = index;\n";
+      os +=
+          "    #pragma unroll\n"
+          "    for (int i = NDIM - 1; i >= 0; i--) {\n";
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        const auto& x = inputs[i];
+        const std::string& xname = namer.get_name(x);
+        if (is_scalar(x) || is_constant(i)) {
+          continue;
+        }
+        os += "      " + xname + "_idx += (loc % shape[i]) * IdxT(" + xname +
+            "_strides[i]);\n";
+      }
+      os +=
+          "      loc /= shape[i];\n"
+          "    }\n"
+          "  }\n";
+    }
+
+    // Work loop
+    os +=
+        "\n"
+        "  for (int i = 0; i < work_per_thread && index < size; i++) {\n";
 
     // Read inputs.
     for (size_t i = 0; i < inputs.size(); ++i) {
@@ -89,12 +124,9 @@ struct FusedKernelBuilder {
       } else if (contiguous) {
         value = fmt::format("{}[index]", xname);
       } else {
-        std::string index = fmt::format(
-            "elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
-            xname);
-        value = fmt::format("{}[{}]", xname, index);
+        value = fmt::format("{}[{}_idx]", xname, xname);
       }
-      os += fmt::format("  {} tmp_{} = {};\n", type, xname, value);
+      os += fmt::format("    {} tmp_{} = {};\n", type, xname, value);
     }
 
     // Write tape.
@@ -113,14 +145,30 @@ struct FusedKernelBuilder {
         }
         value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
       }
-      os += fmt::format("  {} tmp_{} = {};\n", type, xname, value);
+      os += fmt::format("    {} tmp_{} = {};\n", type, xname, value);
     }
 
     // Write output.
     for (const auto& x : outputs) {
-      os += fmt::format("  {0}[index] = tmp_{0};\n", namer.get_name(x));
+      os += fmt::format("    {0}[index] = tmp_{0};\n", namer.get_name(x));
     }
 
+    // End of work loop
+    os +=
+        "\n"
+        "    index++;\n";
+    if (!contiguous) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        const auto& x = inputs[i];
+        const std::string& xname = namer.get_name(x);
+        if (is_scalar(x) || is_constant(i)) {
+          continue;
+        }
+        os += "    " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
+      }
+    }
+    os += "  }\n";
+
     os += "}\n";
   }
 };
@@ -156,15 +204,28 @@ void Compiled::eval_gpu(
     builder.build("_strided", false);
     builder.os += "\n} // namespace mlx::core::cu\n";
     // Build kernel names.
-    std::vector<std::string> kernel_names = {
-        fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
-        fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
-    };
-    for (int i = 1; i <= MAX_NDIM; ++i) {
+    std::vector<std::string> kernel_names;
+    for (auto work_per_thread : std::array{1, 4}) {
       kernel_names.push_back(fmt::format(
-          "mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
-      kernel_names.push_back(
-          fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
+          "mlx::core::cu::{}_contiguous<uint32_t, {}>",
+          lib_name(),
+          work_per_thread));
+      kernel_names.push_back(fmt::format(
+          "mlx::core::cu::{}_contiguous<int64_t, {}>",
+          lib_name(),
+          work_per_thread));
+      for (int i = 1; i <= MAX_NDIM; ++i) {
+        kernel_names.push_back(fmt::format(
+            "mlx::core::cu::{}_strided<{}, uint32_t, {}>",
+            lib_name(),
+            i,
+            work_per_thread));
+        kernel_names.push_back(fmt::format(
+            "mlx::core::cu::{}_strided<{}, int64_t, {}>",
+            lib_name(),
+            i,
+            work_per_thread));
+      }
     }
     return std::make_pair(std::move(builder.os), std::move(kernel_names));
   });
@@ -207,13 +268,21 @@ void Compiled::eval_gpu(
     args.append(outputs[0].data_size());
   }
 
+  // Choose work per thread
+  int work_per_thread = 4;
+  if (!contiguous && shape.back() % work_per_thread != 0) {
+    work_per_thread = 1;
+  }
+
   // Launch kernel.
   const char* index_type = large ? "int64_t" : "uint32_t";
   std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
   if (contiguous) {
-    kernel_name += fmt::format("_contiguous<{}>", index_type);
+    kernel_name +=
+        fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
   } else {
-    kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
+    kernel_name += fmt::format(
+        "_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
   }
   auto& encoder = cu::get_command_encoder(s);
   for (const auto& in : inputs) {
@@ -224,7 +293,8 @@ void Compiled::eval_gpu(
   }
 
   auto kernel = mod.get_kernel(kernel_name);
-  auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
+  auto [num_blocks, block_dims] =
+      get_launch_args(kernel, outputs[0], large, work_per_thread);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp
index 343db902e..a9e5631de 100644
--- a/mlx/backend/cuda/jit_module.cpp
+++ b/mlx/backend/cuda/jit_module.cpp
@@ -121,7 +121,8 @@ void write_cached_ptx(
     const std::filesystem::path& cache_dir,
     const std::string& module_name,
     const std::vector<char>& ptx,
-    const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
+    const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
+    const std::string& source_code) {
   if (cache_dir.empty()) {
     return;
   }
@@ -134,6 +135,9 @@ void write_cached_ptx(
   for (const auto& [name, mangled] : ptx_kernels) {
     txt_file << name << "\t" << mangled << std::endl;
   }
+
+  std::ofstream source_file(cache_dir / (module_name + ".cu"));
+  source_file << source_code;
 }
 
 // Return if |device|'s version is not newer than |major|.|minor| version.
@@ -272,7 +276,8 @@ JitModule::JitModule(
     } else {
      CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
     }
-    write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
+    write_cached_ptx(
+        ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
   }
 
   // Load module.
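
Note for reviewers: the sketch below is illustrative only and not part of the patch. It shows roughly the shape of the source FusedKernelBuilder now emits for a contiguous fused kernel, using a hypothetical fused add with inputs a and b; the actual kernel name, parameter list, and element types are generated from the compiled graph, and the prelude headers come from the JIT module.

#include <cooperative_groups.h>

namespace cg = cooperative_groups;

// Roughly mirrors the strings concatenated by FusedKernelBuilder above for
// the contiguous case after this change. Names and float types are
// placeholders for whatever the graph produces.
template <typename IdxT = uint32_t, int work_per_thread = 1>
__global__ void example_contiguous(
    const float* a,
    const float* b,
    float* out,
    IdxT size) {
  // Each thread owns a strip of work_per_thread consecutive elements.
  IdxT index = cg::this_grid().thread_rank() * work_per_thread;
  if (index >= size) {
    return;
  }

  // Work loop: process up to work_per_thread elements, guarding against
  // running past the end of the output.
  for (int i = 0; i < work_per_thread && index < size; i++) {
    float tmp_a = a[index];
    float tmp_b = b[index];
    float tmp_out = tmp_a + tmp_b;
    out[index] = tmp_out;

    index++;
  }
}

On the host side, get_launch_args is passed work_per_thread so the grid is sized for size / work_per_thread threads; the strided variant keeps one running index per non-scalar input and bumps it by the innermost stride at the end of each iteration, which is why work_per_thread falls back to 1 when the last shape dimension is not divisible by it.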