mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Fix work-per-thread for strided kernels
This commit is contained in:
@@ -205,7 +205,7 @@ void Compiled::eval_gpu(
|
||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||
// Build kernel names.
|
||||
std::vector<std::string> kernel_names;
|
||||
for (auto work_per_thread : std::array<int, 1>{4}) {
|
||||
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||
lib_name(),
|
||||
@@ -268,10 +268,15 @@ void Compiled::eval_gpu(
|
||||
args.append<uint32_t>(outputs[0].data_size());
|
||||
}
|
||||
|
||||
// Choose work per thread
|
||||
int work_per_thread = 4;
|
||||
if (!contiguous && shape.back() % work_per_thread != 0) {
|
||||
work_per_thread = 1;
|
||||
}
|
||||
|
||||
// Launch kernel.
|
||||
const char* index_type = large ? "int64_t" : "uint32_t";
|
||||
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
|
||||
int work_per_thread = 4;
|
||||
if (contiguous) {
|
||||
kernel_name +=
|
||||
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
|
||||
@@ -288,8 +293,8 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
|
||||
num_blocks.x = (num_blocks.x + work_per_thread - 1) / work_per_thread;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, outputs[0], large, work_per_thread);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user