Reductions update (#1351)

Angelos Katharopoulos, 2024-11-04 22:25:16 -08:00, committed by GitHub
parent 76f275b4df
commit 248431eb3c
8 changed files with 462 additions and 206 deletions

View File

@@ -144,6 +144,13 @@ def reduction(op, axis, x):
     mx.eval(ys)


+def sum_and_add(axis, x, y):
+    z = x.sum(axis=axis, keepdims=True)
+    for i in range(50):
+        z = (z + y).sum(axis=axis, keepdims=True)
+    mx.eval(z)
+
+
 def softmax(axis, x):
     ys = []
     for i in range(100):

@@ -505,5 +512,8 @@ if __name__ == "__main__":
     elif args.benchmark == "selu":
         print(bench(selu, x))

+    elif args.benchmark == "sum_and_add":
+        print(bench(sum_and_add, axis, *xs))
+
     else:
         raise ValueError("Unknown benchmark")

View File

@@ -319,16 +319,18 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
 MTL::ComputePipelineState* get_reduce_init_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& func_name,
+    const std::string& op_name,
     const array& out) {
   auto lib = d.get_library(kernel_name, [&]() {
     std::ostringstream kernel_source;
-    std::string op_type = op_name(out);
-    op_type[0] = std::toupper(op_name(out)[0]);
+    std::string op_type = op_name;
+    op_type[0] = std::toupper(op_name[0]);
     auto out_type = get_type_string(out.dtype());
     std::string op = op_type + "<" + out_type + ">";
     kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
     kernel_source << get_template_definition(
-        kernel_name, "init_reduce", out_type, op);
+        kernel_name, func_name, out_type, op);
     return kernel_source.str();
   });
   return d.get_kernel(kernel_name, lib);
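Note on the signature change: kernel_name is the cache key used to look up or build the library, while func_name is the bare Metal function being instantiated, and op_name now arrives as an explicit string instead of being recomputed from the output array. A minimal sketch of the op-string construction, assuming a float32 sum (the helper name is illustrative, not part of the diff):

#include <cctype>
#include <string>

// Illustrative only: "sum" + "float" -> "Sum<float>", which the JIT pastes
// into get_template_definition to instantiate init_reduce<float, Sum<float>>.
std::string op_template_arg(std::string op_name, const std::string& out_type) {
  op_name[0] = std::toupper(static_cast<unsigned char>(op_name[0]));
  return op_name + "<" + out_type + ">";
}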

View File

@@ -79,6 +79,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
 MTL::ComputePipelineState* get_reduce_init_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& func_name,
+    const std::string& op_name,
     const array& out);

 MTL::ComputePipelineState* get_reduce_kernel(

View File

@@ -113,9 +113,12 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
 // special case bool with larger output type
 instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)

 #define instantiate_col_reduce_small(name, itype, otype, op, dim) \
   instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
       col_reduce_small, \
+      itype, otype, op, dim) \
+  instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
+      col_reduce_longcolumn, \
       itype, otype, op, dim)

 #define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
@@ -123,9 +126,14 @@ instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
       col_reduce_looped, \
       itype, otype, op, dim, bm, bn)

+#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
+  instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
+      col_reduce_2pass, \
+      itype, otype, op, dim, bm, bn)
+
 #define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
-  instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
-  instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
+  instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
+  instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)

 #define instantiate_col_reduce_general(name, itype, otype, op) \
   instantiate_col_reduce_small(name, itype, otype, op, 0) \
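Note: each instantiate_kernel line stamps out a Metal kernel whose host name encodes the variant, the looped-dims count, the tile, the op, and the dtype; the host side composes the same string when it fetches the pipeline. A sketch of the naming scheme (the helper is hypothetical, shown for orientation only):

#include <string>

// e.g. col_reduce_kernel_name("2pass", 1, "_32_32", "sum", "float32")
//   -> "col_reduce_2pass_1_32_32_reduce_sumfloat32"
std::string col_reduce_kernel_name(
    const std::string& variant, // "small", "longcolumn", "looped", "2pass"
    int dim,
    const std::string& tile, // "_32_32" for tiled variants, "" otherwise
    const std::string& op,
    const std::string& type) {
  return "col_reduce_" + variant + "_" + std::to_string(dim) + tile +
      "_reduce_" + op + type;
}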

View File

@@ -1,11 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.

-template <
-    typename T,
-    typename U,
-    typename Op,
-    int NDIMS,
-    int N_READS = REDUCE_N_READS>
+template <typename T, typename U, typename Op, int NDIMS>
 [[kernel]] void col_reduce_small(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -20,170 +15,128 @@ template <typename T, typename U, typename Op, int NDIMS>
     const constant size_t& non_col_reductions [[buffer(10)]],
     uint3 gid [[threadgroup_position_in_grid]],
     uint3 gsize [[threadgroups_per_grid]],
-    uint simd_lane_id [[thread_index_in_simdgroup]],
-    uint simd_group_id [[simdgroup_index_in_threadgroup]],
-    uint3 tid [[thread_position_in_grid]],
-    uint3 tsize [[threads_per_grid]]) {
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 lsize [[threads_per_threadgroup]]) {
+  constexpr int n_reads = 4;
   Op op;
   looped_elem_to_loc<NDIMS> loop;
   const device T* row;

-  // Case 1: Small row small column
-  if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
-    U totals[31];
-    for (int i = 0; i < 31; i++) {
-      totals[i] = Op::init;
-    }
-
-    short stride = reduction_stride;
-    short size = reduction_size;
-    short blocks = stride / N_READS;
-    short extra = stride - blocks * N_READS;
-
-    size_t out_idx = tid.x + tsize.y * size_t(tid.y);
-    in += elem_to_loc(out_idx, shape, strides, ndim);
-
-    for (uint r = 0; r < non_col_reductions; r++) {
-      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
-
-      for (short i = 0; i < size; i++) {
-        for (short j = 0; j < blocks; j++) {
-          for (short k = 0; k < N_READS; k++) {
-            totals[j * N_READS + k] =
-                op(totals[j * N_READS + k],
-                   static_cast<U>(row[i * stride + j * N_READS + k]));
-          }
-        }
-        for (short k = 0; k < extra; k++) {
-          totals[blocks * N_READS + k] =
-              op(totals[blocks * N_READS + k],
-                 static_cast<U>(row[i * stride + blocks * N_READS + k]));
-        }
-      }
-
-      loop.next(reduce_shape, reduce_strides);
-    }
-    out += out_idx * reduction_stride;
-    for (short j = 0; j < stride; j++) {
-      out[j] = totals[j];
-    }
-  }
-
-  // Case 2: Long row small column
-  else if (reduction_size * non_col_reductions < 32) {
-    U totals[N_READS];
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = Op::init;
-    }
-
-    short size = reduction_size;
-    size_t offset = size_t(tid.x) * N_READS;
-    bool safe = offset + N_READS <= reduction_stride;
-    short extra = reduction_stride - offset;
-
-    size_t out_idx = tid.y + tsize.z * size_t(tid.z);
-    in += elem_to_loc(out_idx, shape, strides, ndim) + offset;
-
-    for (uint r = 0; r < non_col_reductions; r++) {
-      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
-
-      if (safe) {
-        for (short i = 0; i < size; i++) {
-          for (short j = 0; j < N_READS; j++) {
-            totals[j] =
-                op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
-          }
-        }
-      } else {
-        for (short i = 0; i < size; i++) {
-          for (short j = 0; j < extra; j++) {
-            totals[j] =
-                op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
-          }
-        }
-      }
-
-      loop.next(reduce_shape, reduce_strides);
-    }
-    out += out_idx * reduction_stride + offset;
-    if (safe) {
-      for (short i = 0; i < N_READS; i++) {
-        out[i] = totals[i];
-      }
-    } else {
-      for (short i = 0; i < extra; i++) {
-        out[i] = totals[i];
-      }
-    }
-  }
-
-  // Case 3: Long row medium column
-  else {
-    threadgroup U shared_vals[1024];
-    U totals[N_READS];
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = Op::init;
-    }
-
-    short stride = reduction_stride;
-    short lid = simd_group_id * simd_size + simd_lane_id;
-    short2 tile((stride + N_READS - 1) / N_READS, 32);
-    short2 offset((lid % tile.x) * N_READS, lid / tile.x);
-    short sm_stride = tile.x * N_READS;
-    bool safe = offset.x + N_READS <= stride;
-
-    size_t out_idx = gid.y + gsize.y * size_t(gid.z);
-    in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x;
-
-    // Read cooperatively and contiguously and aggregate the partial results.
-    size_t total = non_col_reductions * reduction_size;
-    loop.next(offset.y, reduce_shape, reduce_strides);
-    for (size_t r = offset.y; r < total; r += simd_size) {
-      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
-
-      if (safe) {
-        for (int i = 0; i < N_READS; i++) {
-          totals[i] = op(static_cast<U>(row[i]), totals[i]);
-        }
-      } else {
-        U vals[N_READS];
-        for (int i = 0; i < N_READS; i++) {
-          vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) : op.init;
-        }
-        for (int i = 0; i < N_READS; i++) {
-          totals[i] = op(vals[i], totals[i]);
-        }
-      }
-
-      loop.next(simd_size, reduce_shape, reduce_strides);
-    }
-
-    // Each thread holds N_READS partial results but the simdgroups are not
-    // aligned to do the reduction across the simdgroup so we write our results
-    // in the shared memory and read them back according to the simdgroup.
-    for (int i = 0; i < N_READS; i++) {
-      shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = op.simd_reduce(
-          shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);
-    }
-
-    // Write the output.
-    if (simd_lane_id == 0) {
-      short column = simd_group_id * N_READS;
-      out += out_idx * reduction_stride + column;
-      if (column + N_READS <= stride) {
-        for (int i = 0; i < N_READS; i++) {
-          out[i] = totals[i];
-        }
-      } else {
-        for (int i = 0; column + i < stride; i++) {
-          out[i] = totals[i];
-        }
-      }
-    }
+  U totals[n_reads];
+  for (int i = 0; i < n_reads; i++) {
+    totals[i] = Op::init;
+  }
+
+  size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
+  if (column >= reduction_stride) {
+    return;
+  }
+  bool safe = column + n_reads <= reduction_stride;
+
+  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
+  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  in += in_idx + column;
+
+  size_t total_rows = non_col_reductions * reduction_size;
+  loop.next(lid.y, reduce_shape, reduce_strides);
+  for (size_t r = lid.y; r < total_rows; r += lsize.y) {
+    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+    if (safe) {
+      for (int i = 0; i < n_reads; i++) {
+        totals[i] = op(static_cast<U>(row[i]), totals[i]);
+      }
+    } else {
+      U vals[n_reads];
+      for (int i = 0; i < n_reads; i++) {
+        vals[i] =
+            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
+      }
+      for (int i = 0; i < n_reads; i++) {
+        totals[i] = op(vals[i], totals[i]);
+      }
+    }
+    loop.next(lsize.y, reduce_shape, reduce_strides);
+  }
+
+  if (lsize.y > 1) {
+    // lsize.y should be <= 8
+    threadgroup U shared_vals[32 * 8 * n_reads];
+    for (int i = 0; i < n_reads; i++) {
+      shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (lid.y == 0) {
+      for (int i = 0; i < n_reads; i++) {
+        totals[i] = shared_vals[lid.x * n_reads + i];
+      }
+      for (uint j = 1; j < lsize.y; j++) {
+        for (int i = 0; i < n_reads; i++) {
+          totals[i] =
+              op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],
+                 totals[i]);
+        }
+      }
+    }
+  }
+
+  if (lid.y == 0) {
+    out += out_idx * reduction_stride + column;
+    if (safe) {
+      for (int i = 0; i < n_reads; i++) {
+        out[i] = totals[i];
+      }
+    } else {
+      for (int i = 0; column + i < reduction_stride; i++) {
+        out[i] = totals[i];
+      }
+    }
+  }
+}
+
+template <typename T, typename U, typename Op, int NDIMS>
+[[kernel]] void col_reduce_longcolumn(
+    const device T* in [[buffer(0)]],
+    device U* out [[buffer(1)]],
+    const constant size_t& reduction_size [[buffer(2)]],
+    const constant size_t& reduction_stride [[buffer(3)]],
+    const constant int* shape [[buffer(4)]],
+    const constant size_t* strides [[buffer(5)]],
+    const constant int& ndim [[buffer(6)]],
+    const constant int* reduce_shape [[buffer(7)]],
+    const constant size_t* reduce_strides [[buffer(8)]],
+    const constant int& reduce_ndim [[buffer(9)]],
+    const constant size_t& non_col_reductions [[buffer(10)]],
+    const constant size_t& out_size [[buffer(11)]],
+    uint3 gid [[threadgroup_position_in_grid]],
+    uint3 gsize [[threadgroups_per_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 lsize [[threads_per_threadgroup]]) {
+  Op op;
+  looped_elem_to_loc<NDIMS> loop;
+  const device T* row;
+
+  size_t out_idx = gid.x + gsize.x * size_t(gid.y);
+  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  in += in_idx + lid.x;
+
+  U total = Op::init;
+  size_t total_rows = non_col_reductions * reduction_size;
+  loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
+  for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
+       r += lsize.y * gsize.z) {
+    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+    total = op(static_cast<U>(*row), total);
+    loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
+  }
+
+  threadgroup U shared_vals[32 * 32];
+  shared_vals[lid.y * lsize.x + lid.x] = total;
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  if (lid.y == 0) {
+    for (uint i = 1; i < lsize.y; i++) {
+      total = op(total, shared_vals[i * lsize.x + lid.x]);
+    }
+    out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
   }
 }
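Note on the rewrite: the three special cases are gone. col_reduce_small now gives every thread a contiguous strip of n_reads = 4 output columns, strides over the reduction rows in lid.y steps, and (when lsize.y > 1) combines up to 8 row-slices through threadgroup memory before thread-row 0 writes out. col_reduce_longcolumn handles thin-but-deep reductions by splitting the rows across gsize.z threadgroups, each writing one slice of an (outer_blocks, out_size) intermediate that a second pass collapses. A sketch of the matching threadgroup sizing, mirroring the host logic later in this diff:

#include <algorithm>
#include <cstddef>

// max_threads is the pipeline's maxTotalThreadsPerThreadgroup.
void small_col_reduce_dims(
    size_t reduction_stride,
    size_t reduction_size,
    size_t non_col_reductions,
    size_t max_threads,
    size_t& tg_x,
    size_t& tg_y) {
  const size_t n_reads = 4;
  size_t stride_blocks = (reduction_stride + n_reads - 1) / n_reads;
  size_t total_rows = reduction_size * non_col_reductions;
  tg_x = std::min<size_t>(stride_blocks, 32); // threads across the columns
  tg_y = std::min<size_t>( // row-slices reduced in parallel, capped at 8
      8,
      std::min(max_threads / tg_x, total_rows));
}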
@@ -216,7 +169,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
   Op op;
-  constexpr int n_simdgroups = 4;
+  constexpr int n_simdgroups = 8;
   constexpr short tgp_size = n_simdgroups * simd_size;
   constexpr short n_reads = (BM * BN) / tgp_size;
   constexpr short n_read_blocks = BN / n_reads;
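Note: with n_simdgroups doubled to 8, the looped kernel runs 256 threads per threadgroup, which changes the per-thread read count for the fixed 32x32 tile. The arithmetic, as a compile-time sketch (simd_size is 32 on Apple GPUs):

// Thread layout implied by the constants above.
constexpr int simd_size = 32;
constexpr int n_simdgroups = 8;
constexpr int tgp_size = n_simdgroups * simd_size; // 256 threads
constexpr int BM = 32, BN = 32;
constexpr int n_reads = (BM * BN) / tgp_size; // 4 values per thread
constexpr int n_read_blocks = BN / n_reads; // 8 threads cover a row of BN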
@@ -329,3 +282,103 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
     }
   }
 }
+
+template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
+[[kernel]] void col_reduce_2pass(
+    const device T* in [[buffer(0)]],
+    device U* out [[buffer(1)]],
+    const constant size_t& reduction_size [[buffer(2)]],
+    const constant size_t& reduction_stride [[buffer(3)]],
+    const constant int* shape [[buffer(4)]],
+    const constant size_t* strides [[buffer(5)]],
+    const constant int& ndim [[buffer(6)]],
+    const constant int* reduce_shape [[buffer(7)]],
+    const constant size_t* reduce_strides [[buffer(8)]],
+    const constant int& reduce_ndim [[buffer(9)]],
+    const constant size_t& non_col_reductions [[buffer(10)]],
+    const constant size_t& out_size [[buffer(11)]],
+    uint3 gid [[threadgroup_position_in_grid]],
+    uint3 gsize [[threadgroups_per_grid]],
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  Op op;
+  constexpr int n_simdgroups = 8;
+  constexpr short tgp_size = n_simdgroups * simd_size;
+  constexpr short n_reads = (BM * BN) / tgp_size;
+  constexpr short n_read_blocks = BN / n_reads;
+  constexpr int n_outputs = BN / n_simdgroups;
+  constexpr short outer_blocks = 32;
+  static_assert(BM == 32, "BM should be equal to 32");
+
+  threadgroup U shared_vals[BN * BM];
+  U totals[n_reads];
+  looped_elem_to_loc<NDIMS> loop;
+  const device T* row;
+
+  for (int i = 0; i < n_reads; i++) {
+    totals[i] = Op::init;
+  }
+
+  short lid = simd_group_id * simd_size + simd_lane_id;
+  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
+  size_t column = BN * gid.x + offset.x;
+  bool safe = column + n_reads <= reduction_stride;
+
+  size_t full_idx = gid.y + gsize.y * size_t(gid.z);
+  size_t block_idx = full_idx / out_size;
+  size_t out_idx = full_idx % out_size;
+  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  in += in_idx + column;
+
+  size_t total = non_col_reductions * reduction_size;
+  loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
+  for (size_t r = offset.y + block_idx * BM; r < total;
+       r += outer_blocks * BM) {
+    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+
+    if (safe) {
+      for (int i = 0; i < n_reads; i++) {
+        totals[i] = op(static_cast<U>(row[i]), totals[i]);
+      }
+    } else {
+      U vals[n_reads];
+      for (int i = 0; i < n_reads; i++) {
+        vals[i] =
+            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
+      }
+      for (int i = 0; i < n_reads; i++) {
+        totals[i] = op(vals[i], totals[i]);
+      }
+    }
+
+    loop.next(outer_blocks * BM, reduce_shape, reduce_strides);
+  }
+
+  // We can use a simd reduction to accumulate across BM so each thread writes
+  // the partial output to SM and then each simdgroup does BN / n_simdgroups
+  // accumulations.
+  for (int i = 0; i < n_reads; i++) {
+    shared_vals[offset.y * BN + offset.x + i] = totals[i];
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
+  for (int i = 0; i < n_outputs; i++) {
+    totals[i] =
+        op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
+  }
+
+  // Write the output.
+  if (simd_lane_id == 0) {
+    size_t out_column = BN * gid.x + out_offset.x;
+    out += full_idx * reduction_stride + out_column;
+    if (out_column + n_outputs <= reduction_stride) {
+      for (int i = 0; i < n_outputs; i++) {
+        out[i] = totals[i];
+      }
+    } else {
+      for (int i = 0; out_column + i < reduction_stride; i++) {
+        out[i] = totals[i];
+      }
+    }
+  }
+}
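Note: col_reduce_2pass flattens the output row and the partial-block index into one grid coordinate and recovers both inside the kernel; pass one writes 32 partial rows into an (outer_blocks, out_size) intermediate and pass two reduces them. A sketch of that index split (illustrative helper):

#include <cstddef>

// full_idx = gid.y + gsize.y * gid.z enumerates out_size * outer_blocks
// entries; e.g. out_size = 1000, full_idx = 2345 -> block 2, output row 345.
void decompose_full_idx(
    size_t full_idx, size_t out_size, size_t& block_idx, size_t& out_idx) {
  block_idx = full_idx / out_size; // which of the 32 partial blocks
  out_idx = full_idx % out_size; // which output row
}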

View File

@@ -97,6 +97,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
 MTL::ComputePipelineState* get_reduce_init_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string&,
+    const std::string&,
     const array&) {
   return d.get_kernel(kernel_name);
 }

View File

@@ -141,6 +141,20 @@ struct ColReduceArgs {
     ndim = shape.size();
   }

+  /**
+   * Create the col reduce arguments for reducing the 1st axis of the row
+   * contiguous intermediate array.
+   */
+  ColReduceArgs(const array& intermediate) {
+    assert(intermediate.flags().row_contiguous);
+
+    reduction_size = intermediate.shape(0);
+    reduction_stride = intermediate.size() / reduction_size;
+    non_col_reductions = 1;
+    reduce_ndim = 0;
+    ndim = 0;
+  }
+
   void encode(CommandEncoder& compute_encoder) {
     // Push 0s to avoid encoding empty vectors.
     if (reduce_ndim == 0) {
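Note: because the intermediate is row contiguous with the partial-block axis first, the second pass needs no shape or stride vectors at all. What the new constructor computes, as a sketch with illustrative numbers:

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// For an intermediate of shape (outer_blocks, *out_shape), e.g. {32, 128, 64}:
// reduction_size = 32 and reduction_stride = 128 * 64 = 8192, i.e. a plain
// reduction over axis 0.
void intermediate_args(
    const std::vector<size_t>& shape, size_t& size, size_t& stride) {
  size = shape[0];
  size_t total = std::accumulate(
      shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());
  stride = total / size;
}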
@@ -231,8 +245,10 @@ void init_reduce(
     CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
-  auto kernel = get_reduce_init_kernel(
-      d, "init_reduce_" + op_name + type_to_name(out), out);
+  std::ostringstream kname;
+  const std::string func_name = "init_reduce";
+  kname << func_name << "_" << op_name << type_to_name(out);
+  auto kernel = get_reduce_init_kernel(d, kname.str(), func_name, op_name, out);
   size_t nthreads = out.size();
   MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@@ -251,8 +267,7 @@ void all_reduce_dispatch(
     const std::string& op_name,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s,
-    std::vector<array>& copies) {
+    const Stream& s) {
   // Set the kernel
   std::ostringstream kname;
   const std::string func_name = "all_reduce";
@@ -293,7 +308,7 @@ void all_reduce_dispatch(
     // Allocate an intermediate tensor to hold results if needed
     array intermediate({n_rows}, out.dtype(), nullptr, {});
     intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
-    copies.push_back(intermediate);
+    d.add_temporary(intermediate, s.index);

     // 1st pass
     size_t row_size = (in_size + n_rows - 1) / n_rows;
@@ -469,39 +484,11 @@ void strided_reduce_small(
   // Figure out the grid dims
   MTL::Size grid_dims, group_dims;

-  // Case 1: Small row small column
-  if (args.reduction_size * args.non_col_reductions < 64 &&
-      args.reduction_stride < 32) {
-    grid_dims = output_grid_for_col_reduce(out, args);
-    int threadgroup_size = (grid_dims.width > 128) ? 128 : grid_dims.width;
-    group_dims = MTL::Size(threadgroup_size, 1, 1);
-  }
-
-  // Case 2: Long row small column
-  else if (args.reduction_size * args.non_col_reductions < 32) {
-    auto out_grid_dims = output_grid_for_col_reduce(out, args);
-    int threads_x =
-        (args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
-    int threadgroup_x = std::min(threads_x, 128);
-    grid_dims = MTL::Size(threads_x, out_grid_dims.width, out_grid_dims.height);
-    group_dims = MTL::Size(threadgroup_x, 1, 1);
-  }
-
-  // Case 3: Long row medium column
-  else {
-    args.reduce_shape.push_back(args.reduction_size);
-    args.reduce_strides.push_back(args.reduction_stride);
-    args.reduce_ndim++;
-
-    int simdgroups =
-        (args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
-    int threadgroup_size = simdgroups * 32;
-    auto out_grid_dims = output_grid_for_col_reduce(out, args);
-    grid_dims =
-        MTL::Size(threadgroup_size, out_grid_dims.width, out_grid_dims.height);
-    group_dims = MTL::Size(threadgroup_size, 1, 1);
-  }
+  // Prepare the arguments for the kernel
+  args.reduce_shape.push_back(args.reduction_size);
+  args.reduce_strides.push_back(args.reduction_stride);
+  args.reduce_ndim++;

   // Set the kernel
   int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
   std::ostringstream kname;
   const std::string func_name = "col_reduce_small";
@@ -510,10 +497,113 @@ void strided_reduce_small(
       get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
   compute_encoder->setComputePipelineState(kernel);

+  const int n_reads = 4;
+  size_t reduction_stride_blocks =
+      (args.reduction_stride + n_reads - 1) / n_reads;
+  size_t total = args.reduction_size * args.non_col_reductions;
+  size_t threadgroup_x = std::min(reduction_stride_blocks, 32ul);
+  size_t threadgroup_y = std::min(
+      8ul,
+      std::min(kernel->maxTotalThreadsPerThreadgroup() / threadgroup_x, total));
+
+  group_dims = MTL::Size(threadgroup_x, threadgroup_y, 1);
+  grid_dims = output_grid_for_col_reduce(out, args);
+  grid_dims = MTL::Size(
+      (reduction_stride_blocks + threadgroup_x - 1) / threadgroup_x,
+      grid_dims.width,
+      grid_dims.height);
+
   // Launch
   compute_encoder.set_input_array(in, 0);
   compute_encoder.set_output_array(out, 1);
   args.encode(compute_encoder);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+}
+
+void strided_reduce_longcolumn(
+    const array& in,
+    array& out,
+    const std::string& op_name,
+    ColReduceArgs& args,
+    CommandEncoder& compute_encoder,
+    metal::Device& d,
+    const Stream& s) {
+  size_t total_reduction_size = args.reduction_size * args.non_col_reductions;
+  size_t outer_blocks = 32;
+  if (total_reduction_size >= 32768) {
+    outer_blocks = 128;
+  }
+
+  // Prepare the temporary accumulator
+  std::vector<int> intermediate_shape;
+  intermediate_shape.reserve(out.ndim() + 1);
+  intermediate_shape.push_back(outer_blocks);
+  intermediate_shape.insert(
+      intermediate_shape.end(), out.shape().begin(), out.shape().end());
+  array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
+  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  d.add_temporary(intermediate, s.index);
+
+  // Prepare the arguments for the kernel
+  args.reduce_shape.push_back(args.reduction_size);
+  args.reduce_strides.push_back(args.reduction_stride);
+  args.reduce_ndim++;
+
+  // Figure out the grid dims
+  size_t out_size = out.size();
+  size_t threadgroup_x = args.reduction_stride;
+  size_t threadgroup_y =
+      (args.non_col_reductions * args.reduction_size + outer_blocks - 1) /
+      outer_blocks;
+  threadgroup_y = std::min(32ul, threadgroup_y);
+  auto out_grid_size = output_grid_for_col_reduce(out, args);
+  MTL::Size grid_dims(out_grid_size.width, out_grid_size.height, outer_blocks);
+  MTL::Size group_dims(threadgroup_x, threadgroup_y, 1);
+
+  // Set the kernel
+  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
+  std::ostringstream kname;
+  const std::string func_name = "col_reduce_longcolumn";
+  kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
+  auto kernel =
+      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
+  compute_encoder->setComputePipelineState(kernel);
+
+  // Launch
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(intermediate, 1);
+  args.encode(compute_encoder);
+  compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+
+  // Make the 2nd pass arguments and grid_dims
+  ColReduceArgs second_args(intermediate);
+  second_args.reduce_shape.push_back(outer_blocks);
+  second_args.reduce_strides.push_back(out.size());
+  second_args.reduce_ndim++;
+  int BN = 32;
+  grid_dims = MTL::Size(256 * ((out.size() + BN - 1) / BN), 1, 1);
+  group_dims = MTL::Size(256, 1, 1);
+
+  // Set the 2nd kernel
+  const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
+      op_name + type_to_name(intermediate);
+  kernel = get_reduce_kernel(
+      d,
+      second_kernel,
+      "col_reduce_looped",
+      op_name,
+      intermediate,
+      out,
+      1,
+      32,
+      32);
+  compute_encoder->setComputePipelineState(kernel);
+
+  compute_encoder.set_input_array(intermediate, 0);
+  compute_encoder.set_output_array(out, 1);
+  second_args.encode(compute_encoder);
   compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
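Note: strided_reduce_longcolumn splits very deep columns across the grid's z dimension (32 blocks, or 128 once the reduction exceeds 32768 rows), then reuses the plain looped kernel to fold the partials. Illustrative arithmetic:

#include <cstddef>

// e.g. 1,000,000 rows with outer_blocks = 128 -> each z-slice accumulates
// 7813 rows, and the second pass reduces 128 partials per output element.
size_t rows_per_block(size_t total_rows, size_t outer_blocks) {
  return (total_rows + outer_blocks - 1) / outer_blocks; // ceil division
}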
@@ -532,9 +622,9 @@ void strided_reduce_looped(
   // Figure out the grid dims
   auto out_grid_size = output_grid_for_col_reduce(out, args);
-  int BN = (args.reduction_stride <= 1024) ? 32 : 128;
+  int BN = 32;
   int BM = 1024 / BN;
-  int threadgroup_size = 4 * 32;
+  int threadgroup_size = 8 * 32;
   MTL::Size grid_dims(
       threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
       out_grid_size.width,
@@ -558,6 +648,87 @@ void strided_reduce_looped(
   compute_encoder.dispatchThreads(grid_dims, group_dims);
 }

+void strided_reduce_2pass(
+    const array& in,
+    array& out,
+    const std::string& op_name,
+    ColReduceArgs& args,
+    CommandEncoder& compute_encoder,
+    metal::Device& d,
+    const Stream& s) {
+  // Prepare the temporary accumulator
+  std::vector<int> intermediate_shape;
+  intermediate_shape.reserve(out.ndim() + 1);
+  intermediate_shape.push_back(32);
+  intermediate_shape.insert(
+      intermediate_shape.end(), out.shape().begin(), out.shape().end());
+  array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
+  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  d.add_temporary(intermediate, s.index);
+
+  // Prepare the arguments for the kernel
+  args.reduce_shape.push_back(args.reduction_size);
+  args.reduce_strides.push_back(args.reduction_stride);
+  args.reduce_ndim++;
+
+  // Figure out the grid dims
+  size_t out_size = out.size() / args.reduction_stride;
+  auto out_grid_size = output_grid_for_col_reduce(out, args);
+  int outer_blocks = 32;
+  int BN = 32;
+  int BM = 1024 / BN;
+  int threadgroup_size = 8 * 32;
+  MTL::Size grid_dims(
+      threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
+      out_grid_size.width * outer_blocks,
+      out_grid_size.height);
+  MTL::Size group_dims(threadgroup_size, 1, 1);
+
+  // Set the kernel
+  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
+  std::ostringstream kname;
+  const std::string func_name = "col_reduce_2pass";
+  kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
+        << op_name << type_to_name(in);
+  auto kernel =
+      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
+  compute_encoder->setComputePipelineState(kernel);
+
+  // Launch
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(intermediate, 1);
+  args.encode(compute_encoder);
+  compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
+
+  // Make the 2nd pass arguments and grid_dims
+  ColReduceArgs second_args(intermediate);
+  second_args.reduce_shape.push_back(outer_blocks);
+  second_args.reduce_strides.push_back(out.size());
+  second_args.reduce_ndim++;
+  grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1);
+
+  // Set the 2nd kernel
+  const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
+      op_name + type_to_name(intermediate);
+  kernel = get_reduce_kernel(
+      d,
+      second_kernel,
+      "col_reduce_looped",
+      op_name,
+      intermediate,
+      out,
+      1,
+      32,
+      32);
+  compute_encoder->setComputePipelineState(kernel);
+
+  compute_encoder.set_input_array(intermediate, 0);
+  compute_encoder.set_output_array(out, 1);
+  second_args.encode(compute_encoder);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
+}
+
 void strided_reduce_general_dispatch(
     const array& in,
     array& out,
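Note: in the 2-pass host code, out_size is out.size() / reduction_stride because the grid's flattened y/z coordinate counts output rows times partial blocks, not output elements. A sketch of those extents (the helper is illustrative):

#include <cstddef>

void grid_extents(
    size_t out_elems,
    size_t reduction_stride,
    size_t outer_blocks,
    size_t& out_rows,
    size_t& flattened) {
  out_rows = out_elems / reduction_stride; // rows of outputs per block
  flattened = out_rows * outer_blocks; // entries enumerated by gid.y/gid.z
}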
@@ -570,11 +741,23 @@ void strided_reduce_general_dispatch(
   // Prepare the arguments for the kernel
   ColReduceArgs args(in, plan, axes);

-  if (args.reduction_stride < 32 ||
-      args.reduction_size * args.non_col_reductions < 32) {
+  // Small column
+  if (args.reduction_size * args.non_col_reductions < 32) {
     return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);
   }

+  // Long column but small row
+  if (args.reduction_stride < 32 &&
+      args.reduction_size * args.non_col_reductions >= 1024) {
+    return strided_reduce_longcolumn(
+        in, out, op_name, args, compute_encoder, d, s);
+  }
+
+  if (args.reduction_size * args.non_col_reductions > 256 &&
+      out.size() / 32 < 1024) {
+    return strided_reduce_2pass(in, out, op_name, args, compute_encoder, d, s);
+  }
+
   return strided_reduce_looped(in, out, op_name, args, compute_encoder, d, s);
 }
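The new dispatch logic condenses to a small decision tree; a sketch (the enum and helper are illustrative, the thresholds are the ones above):

#include <cstddef>

enum class ColReduceKernel { Small, LongColumn, TwoPass, Looped };

ColReduceKernel pick_col_reduce(
    size_t reduction_size,
    size_t non_col_reductions,
    size_t reduction_stride,
    size_t out_elems) {
  size_t total_rows = reduction_size * non_col_reductions;
  if (total_rows < 32) {
    return ColReduceKernel::Small; // few rows: one lightweight pass
  }
  if (reduction_stride < 32 && total_rows >= 1024) {
    return ColReduceKernel::LongColumn; // thin but very deep columns
  }
  if (total_rows > 256 && out_elems / 32 < 1024) {
    return ColReduceKernel::TwoPass; // deep rows, small output
  }
  return ColReduceKernel::Looped; // general tiled fallback
}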
@@ -620,7 +803,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Reduce
   if (in.size() > 0) {
-    std::vector<array> copies;
     ReductionPlan plan = get_reduction_plan(in, axes_);

     // If it is a general reduce then copy the input to a contiguous array and
@@ -632,7 +814,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     if (plan.type == GeneralReduce) {
       array in_copy(in.shape(), in.dtype(), nullptr, {});
       copy_gpu(in, in_copy, CopyType::General, s);
-      copies.push_back(in_copy);
+      d.add_temporary(in_copy, s.index);
       in = in_copy;
       plan = get_reduction_plan(in, axes_);
     }
@@ -640,7 +822,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Reducing over everything and the data is all there no broadcasting or
     // slicing etc.
     if (plan.type == ContiguousAllReduce) {
-      all_reduce_dispatch(in, out, op_name, compute_encoder, d, s, copies);
+      all_reduce_dispatch(in, out, op_name, compute_encoder, d, s);
     }

     // At least the last dimension is row contiguous and we are reducing over
@@ -659,8 +841,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
       strided_reduce_general_dispatch(
           in, out, op_name, plan, axes_, compute_encoder, d, s);
     }
-
-    d.add_temporaries(std::move(copies), s.index);
   }

   // Nothing to reduce just initialize the output

View File

@@ -16,8 +16,7 @@ void all_reduce_dispatch(
     const std::string& op_name,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s,
-    std::vector<array>& copies);
+    const Stream& s);

 void row_reduce_general_dispatch(
     const array& in,