mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00

Further reduction tuning (#1349)

* More reduction tuning
* Forgotten pdb
* Small column long row specialization

parent da8deb2b62
commit b57a52813b
@@ -28,10 +28,8 @@ template <
   looped_elem_to_loc<NDIMS> loop;
   const device T* row;
 
-  // Case 1:
-  // reduction_stride is small, reduction_size is small and non_col_reductions
-  // is small. Each thread computes reduction_stride outputs.
-  if (reduction_size * non_col_reductions < 64) {
+  // Case 1: Small row small column
+  if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
     U totals[31];
     for (int i = 0; i < 31; i++) {
       totals[i] = Op::init;
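Note on the new Case 1 condition: requiring reduction_stride < 32 is what makes the fixed totals[31] register array safe, since a single thread now accumulates all reduction_stride (at most 31) outputs by itself. A minimal scalar model of what one Case 1 thread computes, assuming a contiguous row-major layout (the real kernel resolves locations through elem_to_loc and the loop helper):

// Scalar model of one Case 1 thread: reduce `rows` rows of `stride`
// columns (stride < 32) into `stride` outputs held entirely in registers.
// `rows` stands in for reduction_size * non_col_reductions.
template <typename T, typename Op>
void small_col_reduce_model(
    const T* in, T* out, int rows, int stride, T init, Op op) {
  T totals[31]; // safe because stride < 32
  for (int j = 0; j < stride; j++) {
    totals[j] = init;
  }
  for (int i = 0; i < rows; i++) {
    for (int j = 0; j < stride; j++) {
      totals[j] = op(in[i * stride + j], totals[j]);
    }
  }
  for (int j = 0; j < stride; j++) {
    out[j] = totals[j];
  }
}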
@@ -71,10 +69,55 @@ template <
     }
   }
 
-  // Case 2:
-  // Reduction stride is small but everything else can be big. We loop both
-  // across reduction size and non_col_reductions. Each simdgroup produces
-  // N_READS outputs.
+  // Case 2: Long row small column
+  else if (reduction_size * non_col_reductions < 32) {
+    U totals[N_READS];
+    for (int i = 0; i < N_READS; i++) {
+      totals[i] = Op::init;
+    }
+
+    short size = reduction_size;
+    size_t offset = size_t(tid.x) * N_READS;
+    bool safe = offset + N_READS <= reduction_stride;
+    short extra = reduction_stride - offset;
+
+    size_t out_idx = tid.y + tsize.z * size_t(tid.z);
+    in += elem_to_loc(out_idx, shape, strides, ndim) + offset;
+
+    for (uint r = 0; r < non_col_reductions; r++) {
+      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+
+      if (safe) {
+        for (short i = 0; i < size; i++) {
+          for (short j = 0; j < N_READS; j++) {
+            totals[j] =
+                op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
+          }
+        }
+      } else {
+        for (short i = 0; i < size; i++) {
+          for (short j = 0; j < extra; j++) {
+            totals[j] =
+                op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
+          }
+        }
+      }
+
+      loop.next(reduce_shape, reduce_strides);
+    }
+    out += out_idx * reduction_stride + offset;
+    if (safe) {
+      for (short i = 0; i < N_READS; i++) {
+        out[i] = totals[i];
+      }
+    } else {
+      for (short i = 0; i < extra; i++) {
+        out[i] = totals[i];
+      }
+    }
+  }
+
+  // Case 3: Long row medium column
   else {
     threadgroup U shared_vals[1024];
     U totals[N_READS];
@@ -147,17 +190,13 @@ template <
 /**
  * Our approach is the following simple looped approach:
  * 1. Each thread keeps running totals for BN / n_simdgroups outputs.
- * 2. Load a tile BM, BN in shared memory.
- * 3. Add the values from shared memory to the current running totals.
- *    Neighboring threads access different rows (transposed acces).
- * 4. Move ahead to the next tile until the M axis is exhausted.
- * 5. Move ahead to the next non column reduction
- * 6. Simd reduce the running totals
+ * 2. Load a tile BM, BN in registers and accumulate in the running totals
+ * 3. Move ahead by BM steps until the column axis and the non column
+ *    reductions are exhausted.
+ * 6. If BM == 32 then transpose in SM and simd reduce the running totals.
+ *    Otherwise write in shared memory and BN threads accumulate the running
+ *    totals with a loop.
  * 7. Write them to the output
  *
  * The kernel becomes verbose because we support all kinds of OOB checks. For
 * instance if we choose that reduction_stride must be larger than BN then we
 * can get rid of half the kernel.
 */
 template <
     typename T,
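The rewritten strategy accumulates each BM x BN tile directly in registers instead of staging it through shared memory. A single-threaded sketch of the accumulation pattern, under the assumption of one flat column axis (the real kernel splits the BM rows across simdgroups and still needs shared memory or a simd reduce for the final combine in step 6):

#include <algorithm>

// Scalar model of the looped tiled column reduce: advance along the
// reduction axis in steps of BM, keeping a BN-wide strip of running
// totals in registers the whole time.
template <typename T, typename Op>
void tiled_col_reduce_model(
    const T* in, T* totals, int M, int stride, int BM, int BN,
    int col0, T init, Op op) {
  for (int j = 0; j < BN; j++) {
    totals[j] = init; // step 1: running totals per output column
  }
  for (int m0 = 0; m0 < M; m0 += BM) { // step 3: move ahead by BM steps
    int m1 = std::min(m0 + BM, M);
    for (int i = m0; i < m1; i++) {
      for (int j = 0; j < BN; j++) { // step 2: accumulate the tile
        totals[j] = op(in[i * stride + col0 + j], totals[j]);
      }
    }
  }
  // steps 6-7 (cross-simdgroup combine and write-out) omitted here
}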
@@ -202,7 +202,7 @@ inline int threadgroup_size_from_row_size(int row_size) {
 
   // 2 simdgroups per row for medium rows
   if (row_size <= 1024) {
-    return 64;
+    return 128;
   }
 
   // up to 32 simdgroups after that
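With the bump from 64 to 128, a medium row now gets 4 simdgroups of 32 threads rather than 2 (the unchanged comment above the branch still reads "2 simdgroups per row"). A quick sanity check with the bucket's upper bound:

#include <cstdio>

int main() {
  int row_size = 1024;        // upper end of the medium-row bucket
  int threadgroup_size = 128; // new return value of this branch
  int simdgroups = threadgroup_size / 32;    // 4 simdgroups per row
  int per_simdgroup = row_size / simdgroups; // 256 elements each
  std::printf("%d simdgroups, %d elements each\n", simdgroups, per_simdgroup);
  return 0;
}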
@@ -458,14 +458,25 @@ void strided_reduce_small(
   // Figure out the grid dims
   MTL::Size grid_dims, group_dims;
 
-  // Case 1: everything is small so launch one thread per col reduce
-  if (args.reduction_size * args.non_col_reductions < 64) {
+  // Case 1: Small row small column
+  if (args.reduction_size * args.non_col_reductions < 64 &&
+      args.reduction_stride < 32) {
     grid_dims = output_grid_for_col_reduce(out, args);
     int threadgroup_size = (grid_dims.width > 128) ? 128 : grid_dims.width;
     group_dims = MTL::Size(threadgroup_size, 1, 1);
   }
 
-  // Case 2: Reduction in the simdgroup
+  // Case 2: Long row small column
+  else if (args.reduction_size * args.non_col_reductions < 32) {
+    auto out_grid_dims = output_grid_for_col_reduce(out, args);
+    int threads_x =
+        (args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
+    int threadgroup_x = std::min(threads_x, 128);
+    grid_dims = MTL::Size(threads_x, out_grid_dims.width, out_grid_dims.height);
+    group_dims = MTL::Size(threadgroup_x, 1, 1);
+  }
+
+  // Case 3: Long row medium column
   else {
     args.reduce_shape.push_back(args.reduction_size);
     args.reduce_strides.push_back(args.reduction_stride);
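The Case 2 launch ceil-divides the stride by REDUCE_N_READS to get the x extent, then caps the threadgroup at 128 threads. A worked example with hypothetical sizes:

#include <algorithm>
#include <cstdio>

int main() {
  int reduction_stride = 1000; // hypothetical output row width
  int n_reads = 4;             // stand-in for REDUCE_N_READS
  // ceil(1000 / 4) = 250 threads along x, each producing up to 4 outputs
  int threads_x = (reduction_stride + n_reads - 1) / n_reads;
  int threadgroup_x = std::min(threads_x, 128); // capped at 128
  std::printf("threads_x=%d threadgroup_x=%d\n", threads_x, threadgroup_x);
  return 0;
}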
@@ -508,7 +519,7 @@ void strided_reduce_looped(
 
   // Figure out the grid dims
   auto out_grid_size = output_grid_for_col_reduce(out, args);
-  int BN = (args.reduction_stride <= 256) ? 32 : 128;
+  int BN = (args.reduction_stride <= 1024) ? 32 : 128;
   int BM = 1024 / BN;
   int threadgroup_size = 4 * 32;
   MTL::Size grid_dims(
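Since BM * BN is pinned at 1024, raising the BN threshold from 256 to 1024 keeps the tall 32 x 32 tile for a much wider range of strides; only strides above 1024 now fall back to the wide 8 x 128 tile:

#include <cstdio>
#include <initializer_list>

int main() {
  for (int stride : {512, 2048}) {        // hypothetical reduction strides
    int BN = (stride <= 1024) ? 32 : 128; // new threshold from this diff
    int BM = 1024 / BN;                   // 32x32 tall vs 8x128 wide tile
    std::printf("stride=%d -> BM=%d BN=%d\n", stride, BM, BN);
  }
  return 0;
}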
@@ -544,7 +555,8 @@ void strided_reduce_general_dispatch(
   // Prepare the arguments for the kernel
   ColReduceArgs args(in, plan, axes);
 
-  if (args.reduction_stride < 32) {
+  if (args.reduction_stride < 32 ||
+      args.reduction_size * args.non_col_reductions < 32) {
     return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);
   }
 
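Putting this dispatch condition together with the kernel's case split above gives the following routing; this is a hedged reading of the hunks in this commit, not code from the repository:

#include <cstdio>

// `stride` is args.reduction_stride and `work` is
// reduction_size * non_col_reductions.
const char* route(int stride, long work) {
  if (stride < 32 || work < 32) { // goes to strided_reduce_small
    if (work < 64 && stride < 32) return "small kernel, Case 1";
    if (work < 32) return "small kernel, Case 2";
    return "small kernel, Case 3";
  }
  return "strided_reduce_looped (tiled)";
}

int main() {
  std::printf("%s\n", route(16, 40));    // small stride, small work -> Case 1
  std::printf("%s\n", route(4096, 8));   // long rows, tiny column   -> Case 2
  std::printf("%s\n", route(16, 500));   // small stride, big work   -> Case 3
  std::printf("%s\n", route(4096, 500)); // everything big -> looped kernel
  return 0;
}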
@@ -16,7 +16,8 @@ void all_reduce_dispatch(
     const std::string& op_name,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s);
+    const Stream& s,
+    std::vector<array>& copies);
 
 void row_reduce_general_dispatch(
     const array& in,
@@ -43,10 +43,6 @@ class TestReduce(mlx_tests.MLXTestCase):
         z_npy = np.sum(y_npy, axis=a) / 1000
         z_mlx = mx.sum(y_mlx, axis=a) / 1000
         mx.eval(z_mlx)
-        if not np.allclose(z_npy, np.array(z_mlx), atol=1e-4):
-            import pdb
-
-            pdb.set_trace()
         self.assertTrue(
             np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
         )