mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Re-tune
This commit is contained in:
@@ -163,35 +163,60 @@ __global__ void row_reduce_looped(
|
||||
U init = ReduceInit<Op, T>::value();
|
||||
total[0] = init;
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
size_t full_blocks = args.row_size / (block.size() * N_READS);
|
||||
size_t final_offset = full_blocks * (block.size() * N_READS);
|
||||
const size_t full_blocks = args.row_size / (block.size() * N_READS);
|
||||
const size_t final_offset = full_blocks * (block.size() * N_READS);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
in += block.thread_rank() * N_READS;
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
const T* inlocal = in + loop.location();
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
auto vals = load_vector<N_READS>(inlocal, 0);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
inlocal += block.size() * N_READS;
|
||||
// Unaligned reduce
|
||||
if (final_offset < args.row_size) {
|
||||
bool mask[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
mask[i] =
|
||||
(final_offset + block.thread_rank() * N_READS + i) < args.row_size;
|
||||
}
|
||||
if (final_offset < args.row_size) {
|
||||
T vals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] =
|
||||
((final_offset + block.thread_rank() * N_READS + i) < args.row_size)
|
||||
? inlocal[i]
|
||||
: cast_to<T>(init);
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
const T* inlocal = in + loop.location();
|
||||
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
auto vals = load_vector<N_READS>(inlocal, 0);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
inlocal += block.size() * N_READS;
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
|
||||
{
|
||||
T vals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
}
|
||||
|
||||
// Aligned case
|
||||
else {
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
const T* inlocal = in + loop.location();
|
||||
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
auto vals = load_vector<N_READS>(inlocal, 0);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
inlocal += block.size() * N_READS;
|
||||
}
|
||||
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
// TODO: Maybe block.sync() here?
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
__shared__ U shared_accumulators[32];
|
||||
@@ -231,18 +256,9 @@ void row_reduce_simple(
|
||||
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
||||
if (warps > 128) {
|
||||
warps /= 4;
|
||||
if (warps > 32) {
|
||||
warps = 32;
|
||||
} else {
|
||||
warps = 16;
|
||||
}
|
||||
int best = reductions;
|
||||
for (int j = warps; j >= 1; j /= 2) {
|
||||
int t = reductions % (j * WARP_SIZE);
|
||||
if (t < best) {
|
||||
warps = j;
|
||||
best = t;
|
||||
}
|
||||
}
|
||||
int threads = warps * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
@@ -289,18 +305,9 @@ void row_reduce_looped(
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
||||
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
||||
if (warps > 128) {
|
||||
warps /= 4;
|
||||
if (warps > 32) {
|
||||
warps = 32;
|
||||
} else {
|
||||
warps = 16;
|
||||
}
|
||||
int best = reductions;
|
||||
for (int j = warps; j >= 1; j /= 2) {
|
||||
int t = reductions % (j * WARP_SIZE);
|
||||
if (t < best) {
|
||||
warps = j;
|
||||
best = t;
|
||||
}
|
||||
}
|
||||
int threads = warps * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
Reference in New Issue
Block a user