mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-07 09:14:34 +08:00
Fix work-per-thread for strided kernels
This commit is contained in:
@@ -205,7 +205,7 @@ void Compiled::eval_gpu(
|
|||||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||||
// Build kernel names.
|
// Build kernel names.
|
||||||
std::vector<std::string> kernel_names;
|
std::vector<std::string> kernel_names;
|
||||||
for (auto work_per_thread : std::array<int, 1>{4}) {
|
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
|
||||||
kernel_names.push_back(fmt::format(
|
kernel_names.push_back(fmt::format(
|
||||||
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||||
lib_name(),
|
lib_name(),
|
||||||
@@ -268,10 +268,15 @@ void Compiled::eval_gpu(
|
|||||||
args.append<uint32_t>(outputs[0].data_size());
|
args.append<uint32_t>(outputs[0].data_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Choose work per thread
|
||||||
|
int work_per_thread = 4;
|
||||||
|
if (!contiguous && shape.back() % work_per_thread != 0) {
|
||||||
|
work_per_thread = 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Launch kernel.
|
// Launch kernel.
|
||||||
const char* index_type = large ? "int64_t" : "uint32_t";
|
const char* index_type = large ? "int64_t" : "uint32_t";
|
||||||
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
|
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
|
||||||
int work_per_thread = 4;
|
|
||||||
if (contiguous) {
|
if (contiguous) {
|
||||||
kernel_name +=
|
kernel_name +=
|
||||||
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
|
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
|
||||||
@@ -288,8 +293,8 @@ void Compiled::eval_gpu(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
|
auto [num_blocks, block_dims] =
|
||||||
num_blocks.x = (num_blocks.x + work_per_thread - 1) / work_per_thread;
|
get_launch_args(kernel, outputs[0], large, work_per_thread);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user