diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index d270bee7a..2d08b8fe1 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -71,7 +71,7 @@ struct FusedKernelBuilder { // Index. For non contiguous kernels we create a separate index // variable per variable otherwise everyone uses `index`. os += - " IdxT index = cg::this_grid().thread_rank() * W;\n" + " IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n" " if (index >= size) {\n" " return;\n" " }\n";