mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Fix the template arg name to not clash with inputs
This commit is contained in:
@@ -53,9 +53,10 @@ struct FusedKernelBuilder {
|
||||
|
||||
// Build function signature.
|
||||
if (contiguous) {
|
||||
os += "template <typename IdxT = uint32_t, int W = 1>\n";
|
||||
os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||
} else {
|
||||
os += "template <int NDIM, typename IdxT = uint32_t, int W = 1>\n";
|
||||
os +=
|
||||
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||
}
|
||||
os += fmt::format("__global__ void {}(\n", kernel_name + name);
|
||||
for (size_t i = 0; i < params.size(); ++i) {
|
||||
@@ -106,7 +107,7 @@ struct FusedKernelBuilder {
|
||||
// Work loop
|
||||
os +=
|
||||
"\n"
|
||||
" for (int i = 0; i < W && index < size; i++) {\n";
|
||||
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
|
||||
|
||||
// Read inputs.
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
|
Reference in New Issue
Block a user