mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 18:11:15 +08:00
Further reduction tuning (#1349)
* More reduction tuning
* Remove forgotten pdb breakpoint from the reduce test
* Small column long row specialization
This commit is contained in:
parent
da8deb2b62
commit
b57a52813b
@ -28,10 +28,8 @@ template <
|
|||||||
looped_elem_to_loc<NDIMS> loop;
|
looped_elem_to_loc<NDIMS> loop;
|
||||||
const device T* row;
|
const device T* row;
|
||||||
|
|
||||||
// Case 1:
|
// Case 1: Small row small column
|
||||||
// reduction_stride is small, reduction_size is small and non_col_reductions
|
if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
|
||||||
// is small. Each thread computes reduction_stride outputs.
|
|
||||||
if (reduction_size * non_col_reductions < 64) {
|
|
||||||
U totals[31];
|
U totals[31];
|
||||||
for (int i = 0; i < 31; i++) {
|
for (int i = 0; i < 31; i++) {
|
||||||
totals[i] = Op::init;
|
totals[i] = Op::init;
|
||||||
@ -71,10 +69,55 @@ template <
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case 2:
|
// Case 2: Long row small column
|
||||||
// Reduction stride is small but everything else can be big. We loop both
|
else if (reduction_size * non_col_reductions < 32) {
|
||||||
// across reduction size and non_col_reductions. Each simdgroup produces
|
U totals[N_READS];
|
||||||
// N_READS outputs.
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
totals[i] = Op::init;
|
||||||
|
}
|
||||||
|
|
||||||
|
short size = reduction_size;
|
||||||
|
size_t offset = size_t(tid.x) * N_READS;
|
||||||
|
bool safe = offset + N_READS <= reduction_stride;
|
||||||
|
short extra = reduction_stride - offset;
|
||||||
|
|
||||||
|
size_t out_idx = tid.y + tsize.z * size_t(tid.z);
|
||||||
|
in += elem_to_loc(out_idx, shape, strides, ndim) + offset;
|
||||||
|
|
||||||
|
for (uint r = 0; r < non_col_reductions; r++) {
|
||||||
|
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||||
|
|
||||||
|
if (safe) {
|
||||||
|
for (short i = 0; i < size; i++) {
|
||||||
|
for (short j = 0; j < N_READS; j++) {
|
||||||
|
totals[j] =
|
||||||
|
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (short i = 0; i < size; i++) {
|
||||||
|
for (short j = 0; j < extra; j++) {
|
||||||
|
totals[j] =
|
||||||
|
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
loop.next(reduce_shape, reduce_strides);
|
||||||
|
}
|
||||||
|
out += out_idx * reduction_stride + offset;
|
||||||
|
if (safe) {
|
||||||
|
for (short i = 0; i < N_READS; i++) {
|
||||||
|
out[i] = totals[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (short i = 0; i < extra; i++) {
|
||||||
|
out[i] = totals[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 3: Long row medium column
|
||||||
else {
|
else {
|
||||||
threadgroup U shared_vals[1024];
|
threadgroup U shared_vals[1024];
|
||||||
U totals[N_READS];
|
U totals[N_READS];
|
||||||
@ -147,17 +190,13 @@ template <
|
|||||||
/**
|
/**
|
||||||
* Our approach is the following simple looped approach:
|
* Our approach is the following simple looped approach:
|
||||||
* 1. Each thread keeps running totals for BN / n_simdgroups outputs.
|
* 1. Each thread keeps running totals for BN / n_simdgroups outputs.
|
||||||
* 2. Load a tile BM, BN in shared memory.
|
* 2. Load a tile BM, BN in registers and accumulate in the running totals
|
||||||
* 3. Add the values from shared memory to the current running totals.
|
* 3. Move ahead by BM steps until the column axis and the non column
|
||||||
* Neighboring threads access different rows (transposed access).
|
* reductions are exhausted.
|
||||||
* 4. Move ahead to the next tile until the M axis is exhausted.
|
* 6. If BM == 32 then transpose in SM and simd reduce the running totals.
|
||||||
* 5. Move ahead to the next non column reduction
|
* Otherwise write in shared memory and BN threads accumulate the running
|
||||||
* 6. Simd reduce the running totals
|
* totals with a loop.
|
||||||
* 7. Write them to the output
|
* 7. Write them to the output
|
||||||
*
|
|
||||||
* The kernel becomes verbose because we support all kinds of OOB checks. For
|
|
||||||
* instance if we choose that reduction_stride must be larger than BN then we
|
|
||||||
* can get rid of half the kernel.
|
|
||||||
*/
|
*/
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
|
@ -202,7 +202,7 @@ inline int threadgroup_size_from_row_size(int row_size) {
|
|||||||
|
|
||||||
// 2 simdgroups per row for medium rows
|
// 2 simdgroups per row for medium rows
|
||||||
if (row_size <= 1024) {
|
if (row_size <= 1024) {
|
||||||
return 64;
|
return 128;
|
||||||
}
|
}
|
||||||
|
|
||||||
// up to 32 simdgroups after that
|
// up to 32 simdgroups after that
|
||||||
@ -458,14 +458,25 @@ void strided_reduce_small(
|
|||||||
// Figure out the grid dims
|
// Figure out the grid dims
|
||||||
MTL::Size grid_dims, group_dims;
|
MTL::Size grid_dims, group_dims;
|
||||||
|
|
||||||
// Case 1: everything is small so launch one thread per col reduce
|
// Case 1: Small row small column
|
||||||
if (args.reduction_size * args.non_col_reductions < 64) {
|
if (args.reduction_size * args.non_col_reductions < 64 &&
|
||||||
|
args.reduction_stride < 32) {
|
||||||
grid_dims = output_grid_for_col_reduce(out, args);
|
grid_dims = output_grid_for_col_reduce(out, args);
|
||||||
int threadgroup_size = (grid_dims.width > 128) ? 128 : grid_dims.width;
|
int threadgroup_size = (grid_dims.width > 128) ? 128 : grid_dims.width;
|
||||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case 2: Reduction in the simdgroup
|
// Case 2: Long row small column
|
||||||
|
else if (args.reduction_size * args.non_col_reductions < 32) {
|
||||||
|
auto out_grid_dims = output_grid_for_col_reduce(out, args);
|
||||||
|
int threads_x =
|
||||||
|
(args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
|
||||||
|
int threadgroup_x = std::min(threads_x, 128);
|
||||||
|
grid_dims = MTL::Size(threads_x, out_grid_dims.width, out_grid_dims.height);
|
||||||
|
group_dims = MTL::Size(threadgroup_x, 1, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 3: Long row medium column
|
||||||
else {
|
else {
|
||||||
args.reduce_shape.push_back(args.reduction_size);
|
args.reduce_shape.push_back(args.reduction_size);
|
||||||
args.reduce_strides.push_back(args.reduction_stride);
|
args.reduce_strides.push_back(args.reduction_stride);
|
||||||
@ -508,7 +519,7 @@ void strided_reduce_looped(
|
|||||||
|
|
||||||
// Figure out the grid dims
|
// Figure out the grid dims
|
||||||
auto out_grid_size = output_grid_for_col_reduce(out, args);
|
auto out_grid_size = output_grid_for_col_reduce(out, args);
|
||||||
int BN = (args.reduction_stride <= 256) ? 32 : 128;
|
int BN = (args.reduction_stride <= 1024) ? 32 : 128;
|
||||||
int BM = 1024 / BN;
|
int BM = 1024 / BN;
|
||||||
int threadgroup_size = 4 * 32;
|
int threadgroup_size = 4 * 32;
|
||||||
MTL::Size grid_dims(
|
MTL::Size grid_dims(
|
||||||
@ -544,7 +555,8 @@ void strided_reduce_general_dispatch(
|
|||||||
// Prepare the arguments for the kernel
|
// Prepare the arguments for the kernel
|
||||||
ColReduceArgs args(in, plan, axes);
|
ColReduceArgs args(in, plan, axes);
|
||||||
|
|
||||||
if (args.reduction_stride < 32) {
|
if (args.reduction_stride < 32 ||
|
||||||
|
args.reduction_size * args.non_col_reductions < 32) {
|
||||||
return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);
|
return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,7 +16,8 @@ void all_reduce_dispatch(
|
|||||||
const std::string& op_name,
|
const std::string& op_name,
|
||||||
CommandEncoder& compute_encoder,
|
CommandEncoder& compute_encoder,
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const Stream& s);
|
const Stream& s,
|
||||||
|
std::vector<array>& copies);
|
||||||
|
|
||||||
void row_reduce_general_dispatch(
|
void row_reduce_general_dispatch(
|
||||||
const array& in,
|
const array& in,
|
||||||
|
@ -43,10 +43,6 @@ class TestReduce(mlx_tests.MLXTestCase):
|
|||||||
z_npy = np.sum(y_npy, axis=a) / 1000
|
z_npy = np.sum(y_npy, axis=a) / 1000
|
||||||
z_mlx = mx.sum(y_mlx, axis=a) / 1000
|
z_mlx = mx.sum(y_mlx, axis=a) / 1000
|
||||||
mx.eval(z_mlx)
|
mx.eval(z_mlx)
|
||||||
if not np.allclose(z_npy, np.array(z_mlx), atol=1e-4):
|
|
||||||
import pdb
|
|
||||||
|
|
||||||
pdb.set_trace()
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
|
np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user