Strided reduce specialization for small reductions (#826)

* Add small column / general reduction specialization
This commit is contained in:
Jagrit Digani 2024-03-14 09:16:53 -07:00 committed by GitHub
parent 1efee9db09
commit 8dfc376c00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 123 additions and 8 deletions

View File

@ -6,6 +6,69 @@
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Small column reduce kernel
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op>
[[kernel]] void col_reduce_small(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]) {
// Appease the compiler
(void)out_size;
Op op;
U total_val = Op::init;
auto out_idx = tid;
in += elem_to_loc(
out_idx,
shape + non_col_ndim,
strides + non_col_ndim,
ndim - non_col_ndim);
for(uint i = 0; i < non_col_reductions; i++) {
size_t in_idx = elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
for(uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
U val = static_cast<U>(in[in_idx]);
total_val = op(total_val, val);
}
}
out[out_idx] = total_val;
}
// Explicitly instantiate col_reduce_small for a concrete
// (input type, output type, op) triple and expose it under the
// host-visible name "col_reduce_small_<name>". The parameter list must
// mirror the kernel definition above exactly. (No comments may appear
// inside the macro body: a `//` would swallow the `\` continuations.)
#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("col_reduce_small_" #name)]] \
[[kernel]] void col_reduce_small<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
const constant size_t& non_col_reductions [[buffer(8)]], \
const constant int* non_col_shapes [[buffer(9)]], \
const constant size_t* non_col_strides [[buffer(10)]], \
const constant int& non_col_ndim [[buffer(11)]], \
uint tid [[thread_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Column reduce helper
///////////////////////////////////////////////////////////////////////////////
@ -171,9 +234,11 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
///////////////////////////////////////////////////////////////////////////////
// Instantiate both column-reduce variants (small serial kernel + general
// threadgroup kernel) for a same-input/output-type reduction.
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general(name ##tname, type, type, op<type>)
// Same as above, but pairs the small kernel with the no-atomics general
// kernel (used for types the atomic path cannot handle, e.g. 64-bit ints).
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
@ -181,4 +246,8 @@ instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce
// Boolean special cases: sum of bools accumulates into uint32_t, and the
// and/or reductions operate directly on bool. Both the general and the
// small kernel get the same set of instantiations.
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
// NOTE(review): the line below duplicates the one above — this looks like a
// diff-rendering artifact (old + new context line); a real source file with
// both would define the same host_name twice. Confirm against the repo.
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or)

View File

@ -307,13 +307,6 @@ void strided_reduce_general_dispatch(
metal::Device& d,
const Stream& s) {
Dtype out_dtype = out.dtype();
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel(
"col_reduce_general_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
// Prepare the arguments for the kernel
size_t reduction_size = plan.shape.back();
@ -327,6 +320,11 @@ void strided_reduce_general_dispatch(
for (auto s : shape) {
non_col_reductions *= static_cast<size_t>(s);
}
std::vector<int> non_col_shapes = shape;
std::vector<size_t> non_col_strides = strides;
int non_col_ndim = shape.size();
auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
for (auto s : rem_shape) {
shape.push_back(s);
@ -336,6 +334,54 @@ void strided_reduce_general_dispatch(
}
int ndim = shape.size();
// Specialize for small dims
if (reduction_size * non_col_reductions < 16) {
  // The whole reduction per output element is tiny (< 16 inputs), so the
  // threadgroup machinery of the general kernel is pure overhead. Use the
  // serial col_reduce_small kernel: one thread per output element.
  auto kernel =
      d.get_kernel("col_reduce_small_" + op_name + type_to_name(in));
  compute_encoder->setComputePipelineState(kernel);

  // One thread per output element; 256-wide groups, dispatched with
  // dispatchThreads so a partial final group is handled by the runtime.
  MTL::Size grid_dims = MTL::Size(out_size, 1, 1);
  MTL::Size group_dims = MTL::Size(256ul, 1, 1);

  // setBytes cannot upload an empty vector; pass a degenerate {1}/{1}
  // pair when there are no non-column reduction axes. non_col_ndim stays
  // 0, so the kernel's elem_to_loc over these arrays still yields 0.
  if (non_col_ndim == 0) {
    non_col_shapes = {1};
    non_col_strides = {1};
  }

  // Encode arrays — buffer indices must match the kernel signature.
  set_array_buffer(compute_encoder, in, 0);
  set_array_buffer(compute_encoder, out, 1);
  compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
  compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
  compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
  compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
  compute_encoder->setBytes(
      strides.data(), strides.size() * sizeof(size_t), 6);
  compute_encoder->setBytes(&ndim, sizeof(int), 7);
  compute_encoder->setBytes(&non_col_reductions, sizeof(size_t), 8);
  compute_encoder->setBytes(
      non_col_shapes.data(), non_col_shapes.size() * sizeof(int), 9);
  // Fix: size this upload by non_col_strides (was non_col_shapes.size()).
  // The two vectors are currently always the same length, so behavior is
  // unchanged, but sizing one buffer by the other vector is a latent bug.
  compute_encoder->setBytes(
      non_col_strides.data(), non_col_strides.size() * sizeof(size_t), 10);
  compute_encoder->setBytes(&non_col_ndim, sizeof(int), 11);

  // Dispatch threads
  compute_encoder->dispatchThreads(grid_dims, group_dims);
  return;
}
// Select kernel
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel(
"col_reduce_general_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
// Select block dimensions
// Each thread reads 16 inputs to give it more work
uint n_inputs_per_thread = REDUCE_N_READS;