Fix the template arg name to not clash with inputs

This commit is contained in:
Angelos Katharopoulos
2025-07-14 00:03:21 -07:00
parent 8b77aa9b8d
commit bb341d85b5

View File

@@ -53,9 +53,10 @@ struct FusedKernelBuilder {
// Build function signature. // Build function signature.
if (contiguous) { if (contiguous) {
os += "template <typename IdxT = uint32_t, int W = 1>\n"; os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
} else { } else {
os += "template <int NDIM, typename IdxT = uint32_t, int W = 1>\n"; os +=
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
} }
os += fmt::format("__global__ void {}(\n", kernel_name + name); os += fmt::format("__global__ void {}(\n", kernel_name + name);
for (size_t i = 0; i < params.size(); ++i) { for (size_t i = 0; i < params.size(); ++i) {
@@ -106,7 +107,7 @@ struct FusedKernelBuilder {
// Work loop // Work loop
os += os +=
"\n" "\n"
" for (int i = 0; i < W && index < size; i++) {\n"; " for (int i = 0; i < work_per_thread && index < size; i++) {\n";
// Read inputs. // Read inputs.
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {