mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 07:34:42 +08:00
Fix the template arg name to not clash with inputs
This commit is contained in:
@@ -53,9 +53,10 @@ struct FusedKernelBuilder {
|
|||||||
|
|
||||||
// Build function signature.
|
// Build function signature.
|
||||||
if (contiguous) {
|
if (contiguous) {
|
||||||
os += "template <typename IdxT = uint32_t, int W = 1>\n";
|
os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||||
} else {
|
} else {
|
||||||
os += "template <int NDIM, typename IdxT = uint32_t, int W = 1>\n";
|
os +=
|
||||||
|
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||||
}
|
}
|
||||||
os += fmt::format("__global__ void {}(\n", kernel_name + name);
|
os += fmt::format("__global__ void {}(\n", kernel_name + name);
|
||||||
for (size_t i = 0; i < params.size(); ++i) {
|
for (size_t i = 0; i < params.size(); ++i) {
|
||||||
@@ -106,7 +107,7 @@ struct FusedKernelBuilder {
|
|||||||
// Work loop
|
// Work loop
|
||||||
os +=
|
os +=
|
||||||
"\n"
|
"\n"
|
||||||
" for (int i = 0; i < W && index < size; i++) {\n";
|
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
|
||||||
|
|
||||||
// Read inputs.
|
// Read inputs.
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
Reference in New Issue
Block a user