Fix the template arg name to not clash with inputs

This commit is contained in:
Angelos Katharopoulos
2025-07-14 00:03:21 -07:00
parent 8b77aa9b8d
commit bb341d85b5

View File

@@ -53,9 +53,10 @@ struct FusedKernelBuilder {
// Build function signature.
if (contiguous) {
os += "template <typename IdxT = uint32_t, int W = 1>\n";
os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
} else {
os += "template <int NDIM, typename IdxT = uint32_t, int W = 1>\n";
os +=
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
}
os += fmt::format("__global__ void {}(\n", kernel_name + name);
for (size_t i = 0; i < params.size(); ++i) {
@@ -106,7 +107,7 @@ struct FusedKernelBuilder {
// Work loop
os +=
"\n"
" for (int i = 0; i < W && index < size; i++) {\n";
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
// Read inputs.
for (size_t i = 0; i < inputs.size(); ++i) {