mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-28 21:21:21 +08:00
fix
This commit is contained in:
parent
d81c2ec3af
commit
55d6edcaa3
@ -102,7 +102,7 @@ inline void build_kernel(
|
||||
// a third grid dimension
|
||||
os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
|
||||
} else if (contiguous) {
|
||||
os += " int index = N_ * pos.x;\n";
|
||||
os += " uint index = N_ * pos.x;\n";
|
||||
} else if (work_per_thread > 1) {
|
||||
os += fmt::format(
|
||||
" int xshape = output_shape[{0}];\n",
|
||||
@ -115,6 +115,9 @@ inline void build_kernel(
|
||||
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||
idx_type);
|
||||
}
|
||||
if (work_per_thread > 1 && contiguous) {
|
||||
os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
|
||||
}
|
||||
|
||||
// Read constant / contiguous inputs in tmps
|
||||
std::vector<array> nc_inputs;
|
||||
@ -198,13 +201,9 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Open per-thread loop
|
||||
if (work_per_thread > 1) {
|
||||
if (contiguous) {
|
||||
os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
|
||||
} else {
|
||||
os +=
|
||||
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
|
||||
}
|
||||
if (work_per_thread > 1 && !contiguous) {
|
||||
os +=
|
||||
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
|
||||
}
|
||||
|
||||
// Read non-contiguous inputs into tmps
|
||||
@ -296,7 +295,7 @@ void Compiled::eval_gpu(
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ false,
|
||||
/* work_per_thread */ work_per_thread);
|
||||
/* work_per_thread = */ work_per_thread);
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous_large",
|
||||
@ -308,7 +307,7 @@ void Compiled::eval_gpu(
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false,
|
||||
/* use_big_index = */ true,
|
||||
/* work_per_thread */ work_per_thread);
|
||||
/* work_per_thread = */ work_per_thread);
|
||||
for (int i = 1; i < 8; i++) {
|
||||
build_kernel(
|
||||
kernel,
|
||||
@ -497,7 +496,7 @@ void Compiled::eval_gpu(
|
||||
|
||||
// Launch the kernel
|
||||
if (contiguous) {
|
||||
int work_per_thread = get_work_per_thread(outputs_[0].dtype());
|
||||
int work_per_thread = get_work_per_thread(outputs[0].dtype());
|
||||
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
|
||||
MTL::Size group_dims(
|
||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||
|
@ -106,7 +106,6 @@ void ternary_op_gpu_inplace(
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D or 2D grid of threads
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
|
Loading…
Reference in New Issue
Block a user