Fix work-per-thread for strided kernels

This commit is contained in:
Angelos Katharopoulos
2025-07-14 23:57:23 -07:00
parent e74e593948
commit b24f6f64fd

View File

@@ -205,7 +205,7 @@ void Compiled::eval_gpu(
builder.os += "\n} // namespace mlx::core::cu\n"; builder.os += "\n} // namespace mlx::core::cu\n";
// Build kernel names. // Build kernel names.
std::vector<std::string> kernel_names; std::vector<std::string> kernel_names;
for (auto work_per_thread : std::array<int, 1>{4}) { for (auto work_per_thread : std::array<int, 2>{1, 4}) {
kernel_names.push_back(fmt::format( kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<uint32_t, {}>", "mlx::core::cu::{}_contiguous<uint32_t, {}>",
lib_name(), lib_name(),
@@ -268,10 +268,15 @@ void Compiled::eval_gpu(
args.append<uint32_t>(outputs[0].data_size()); args.append<uint32_t>(outputs[0].data_size());
} }
// Choose work per thread
int work_per_thread = 4;
if (!contiguous && shape.back() % work_per_thread != 0) {
work_per_thread = 1;
}
// Launch kernel. // Launch kernel.
const char* index_type = large ? "int64_t" : "uint32_t"; const char* index_type = large ? "int64_t" : "uint32_t";
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name()); std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
int work_per_thread = 4;
if (contiguous) { if (contiguous) {
kernel_name += kernel_name +=
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread); fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
@@ -288,8 +293,8 @@ void Compiled::eval_gpu(
} }
auto kernel = mod.get_kernel(kernel_name); auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large); auto [num_blocks, block_dims] =
num_blocks.x = (num_blocks.x + work_per_thread - 1) / work_per_thread; get_launch_args(kernel, outputs[0], large, work_per_thread);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
} }