diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index 2d08b8fe1..1aff47b89 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -205,7 +205,7 @@ void Compiled::eval_gpu( builder.os += "\n} // namespace mlx::core::cu\n"; // Build kernel names. std::vector kernel_names; - for (auto work_per_thread : std::array{4}) { + for (auto work_per_thread : std::array{1, 4}) { kernel_names.push_back(fmt::format( "mlx::core::cu::{}_contiguous", lib_name(), @@ -268,10 +268,15 @@ void Compiled::eval_gpu( args.append(outputs[0].data_size()); } + // Choose work per thread + int work_per_thread = 4; + if (!contiguous && shape.back() % work_per_thread != 0) { + work_per_thread = 1; + } + // Launch kernel. const char* index_type = large ? "int64_t" : "uint32_t"; std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name()); - int work_per_thread = 4; if (contiguous) { kernel_name += fmt::format("_contiguous<{}, {}>", index_type, work_per_thread); @@ -288,8 +293,8 @@ void Compiled::eval_gpu( } auto kernel = mod.get_kernel(kernel_name); - auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large); - num_blocks.x = (num_blocks.x + work_per_thread - 1) / work_per_thread; + auto [num_blocks, block_dims] = + get_launch_args(kernel, outputs[0], large, work_per_thread); encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); }