This commit is contained in:
Awni Hannun 2025-05-02 17:44:49 -07:00
parent d81c2ec3af
commit 55d6edcaa3
2 changed files with 10 additions and 12 deletions

View File

@@ -102,7 +102,7 @@ inline void build_kernel(
// a third grid dimension
os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
} else if (contiguous) {
os += " int index = N_ * pos.x;\n";
os += " uint index = N_ * pos.x;\n";
} else if (work_per_thread > 1) {
os += fmt::format(
" int xshape = output_shape[{0}];\n",
@@ -115,6 +115,9 @@ inline void build_kernel(
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
}
if (work_per_thread > 1 && contiguous) {
os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
}
// Read constant / contiguous inputs in tmps
std::vector<array> nc_inputs;
@@ -198,13 +201,9 @@ inline void build_kernel(
}
// Open per-thread loop
if (work_per_thread > 1) {
if (contiguous) {
os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
} else {
os +=
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
}
if (work_per_thread > 1 && !contiguous) {
os +=
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
}
// Read non-contiguous inputs into tmps
@@ -296,7 +295,7 @@ void Compiled::eval_gpu(
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread */ work_per_thread);
/* work_per_thread = */ work_per_thread);
build_kernel(
kernel,
kernel_lib_ + "_contiguous_large",
@@ -308,7 +307,7 @@ void Compiled::eval_gpu(
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ true,
/* work_per_thread */ work_per_thread);
/* work_per_thread = */ work_per_thread);
for (int i = 1; i < 8; i++) {
build_kernel(
kernel,
@@ -497,7 +496,7 @@ void Compiled::eval_gpu(
// Launch the kernel
if (contiguous) {
int work_per_thread = get_work_per_thread(outputs_[0].dtype());
int work_per_thread = get_work_per_thread(outputs[0].dtype());
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);

View File

@@ -106,7 +106,6 @@ void ternary_op_gpu_inplace(
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
if (thread_group_size > nthreads) {
thread_group_size = nthreads;