mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-30 06:12:41 +08:00
fix
This commit is contained in:
parent
d81c2ec3af
commit
55d6edcaa3
@ -102,7 +102,7 @@ inline void build_kernel(
|
|||||||
// a third grid dimension
|
// a third grid dimension
|
||||||
os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
|
os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
|
||||||
} else if (contiguous) {
|
} else if (contiguous) {
|
||||||
os += " int index = N_ * pos.x;\n";
|
os += " uint index = N_ * pos.x;\n";
|
||||||
} else if (work_per_thread > 1) {
|
} else if (work_per_thread > 1) {
|
||||||
os += fmt::format(
|
os += fmt::format(
|
||||||
" int xshape = output_shape[{0}];\n",
|
" int xshape = output_shape[{0}];\n",
|
||||||
@ -115,6 +115,9 @@ inline void build_kernel(
|
|||||||
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
|
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||||
idx_type);
|
idx_type);
|
||||||
}
|
}
|
||||||
|
if (work_per_thread > 1 && contiguous) {
|
||||||
|
os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
|
||||||
|
}
|
||||||
|
|
||||||
// Read constant / contiguous inputs in tmps
|
// Read constant / contiguous inputs in tmps
|
||||||
std::vector<array> nc_inputs;
|
std::vector<array> nc_inputs;
|
||||||
@ -198,13 +201,9 @@ inline void build_kernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Open per-thread loop
|
// Open per-thread loop
|
||||||
if (work_per_thread > 1) {
|
if (work_per_thread > 1 && !contiguous) {
|
||||||
if (contiguous) {
|
os +=
|
||||||
os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
|
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
|
||||||
} else {
|
|
||||||
os +=
|
|
||||||
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read non-contiguous inputs into tmps
|
// Read non-contiguous inputs into tmps
|
||||||
@ -296,7 +295,7 @@ void Compiled::eval_gpu(
|
|||||||
/* ndim = */ 0,
|
/* ndim = */ 0,
|
||||||
/* dynamic_dims = */ false,
|
/* dynamic_dims = */ false,
|
||||||
/* use_big_index = */ false,
|
/* use_big_index = */ false,
|
||||||
/* work_per_thread */ work_per_thread);
|
/* work_per_thread = */ work_per_thread);
|
||||||
build_kernel(
|
build_kernel(
|
||||||
kernel,
|
kernel,
|
||||||
kernel_lib_ + "_contiguous_large",
|
kernel_lib_ + "_contiguous_large",
|
||||||
@ -308,7 +307,7 @@ void Compiled::eval_gpu(
|
|||||||
/* ndim = */ 0,
|
/* ndim = */ 0,
|
||||||
/* dynamic_dims = */ false,
|
/* dynamic_dims = */ false,
|
||||||
/* use_big_index = */ true,
|
/* use_big_index = */ true,
|
||||||
/* work_per_thread */ work_per_thread);
|
/* work_per_thread = */ work_per_thread);
|
||||||
for (int i = 1; i < 8; i++) {
|
for (int i = 1; i < 8; i++) {
|
||||||
build_kernel(
|
build_kernel(
|
||||||
kernel,
|
kernel,
|
||||||
@ -497,7 +496,7 @@ void Compiled::eval_gpu(
|
|||||||
|
|
||||||
// Launch the kernel
|
// Launch the kernel
|
||||||
if (contiguous) {
|
if (contiguous) {
|
||||||
int work_per_thread = get_work_per_thread(outputs_[0].dtype());
|
int work_per_thread = get_work_per_thread(outputs[0].dtype());
|
||||||
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
|
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
|
||||||
MTL::Size group_dims(
|
MTL::Size group_dims(
|
||||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||||
|
@ -106,7 +106,6 @@ void ternary_op_gpu_inplace(
|
|||||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||||
} else {
|
} else {
|
||||||
// Launch a 1D or 2D grid of threads
|
// Launch a 1D or 2D grid of threads
|
||||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
|
||||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||||
if (thread_group_size > nthreads) {
|
if (thread_group_size > nthreads) {
|
||||||
thread_group_size = nthreads;
|
thread_group_size = nthreads;
|
||||||
|
Loading…
Reference in New Issue
Block a user