Fix work-per-thread for strided kernels

This commit is contained in:
Angelos Katharopoulos
2025-07-14 23:57:23 -07:00
parent e74e593948
commit b24f6f64fd

View File

@@ -205,7 +205,7 @@ void Compiled::eval_gpu(
builder.os += "\n} // namespace mlx::core::cu\n";
// Build kernel names.
std::vector<std::string> kernel_names;
for (auto work_per_thread : std::array<int, 1>{4}) {
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
lib_name(),
@@ -268,10 +268,15 @@ void Compiled::eval_gpu(
args.append<uint32_t>(outputs[0].data_size());
}
// Choose work per thread
int work_per_thread = 4;
if (!contiguous && shape.back() % work_per_thread != 0) {
work_per_thread = 1;
}
// Launch kernel.
const char* index_type = large ? "int64_t" : "uint32_t";
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
int work_per_thread = 4;
if (contiguous) {
kernel_name +=
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
@@ -288,8 +293,8 @@ void Compiled::eval_gpu(
}
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
num_blocks.x = (num_blocks.x + work_per_thread - 1) / work_per_thread;
auto [num_blocks, block_dims] =
get_launch_args(kernel, outputs[0], large, work_per_thread);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}