mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Strided reduce specialization for small reductions (#826)
* Add small column / general reduction specialization
This commit is contained in:
parent
1efee9db09
commit
8dfc376c00
@ -6,6 +6,69 @@
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Small column reduce kernel
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
// Small column reduce kernel
///////////////////////////////////////////////////////////////////////////////

// Column reduction specialized for small reduction sizes: one thread per
// output element, serially accumulating every reduced value.
//
// T / U  : input / accumulator-output element types.
// Op     : reduction functor; Op::init is the identity element.
//
// NOTE(review): `elem_to_loc` is a project helper (not visible here); it is
// presumably the usual shape/stride flat-index -> memory-offset conversion.
template <typename T, typename U, typename Op>
[[kernel]] void col_reduce_small(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    const constant size_t& non_col_reductions [[buffer(8)]],
    const constant int* non_col_shapes [[buffer(9)]],
    const constant size_t* non_col_strides [[buffer(10)]],
    const constant int& non_col_ndim [[buffer(11)]],
    uint tid [[thread_position_in_grid]]) {

  // out_size is part of the shared kernel ABI but unused here.
  (void)out_size;

  Op reduce_op;
  U acc = Op::init;

  // Each thread owns exactly one output element.
  const uint output_index = tid;

  // Advance `in` to this output's base location. The leading non_col_ndim
  // entries of shape/strides describe the non-column reduction axes, so we
  // skip past them when locating the output element.
  in += elem_to_loc(
      output_index,
      shape + non_col_ndim,
      strides + non_col_ndim,
      ndim - non_col_ndim);

  // Outer loop: walk the non-column reduction axes; inner loop: walk the
  // column (strided) reduction itself.
  for (uint r = 0; r < non_col_reductions; r++) {
    size_t offset =
        elem_to_loc(r, non_col_shapes, non_col_strides, non_col_ndim);

    for (uint c = 0; c < reduction_size; c++) {
      acc = reduce_op(acc, static_cast<U>(in[offset]));
      offset += reduction_stride;
    }
  }

  out[output_index] = acc;
}
|
||||
|
||||
// Emit an explicit instantiation of col_reduce_small for one
// (input type, output type, op) combination, exported to the host under
// the name "col_reduce_small_<kname>".
#define instantiate_col_reduce_small(kname, in_t, out_t, red_op)      \
  template [[host_name("col_reduce_small_" #kname)]]                  \
  [[kernel]] void col_reduce_small<in_t, out_t, red_op>(              \
      const device in_t *in [[buffer(0)]],                            \
      device out_t *out [[buffer(1)]],                                \
      const constant size_t& reduction_size [[buffer(2)]],            \
      const constant size_t& reduction_stride [[buffer(3)]],          \
      const constant size_t& out_size [[buffer(4)]],                  \
      const constant int* shape [[buffer(5)]],                        \
      const constant size_t* strides [[buffer(6)]],                   \
      const constant int& ndim [[buffer(7)]],                         \
      const constant size_t& non_col_reductions [[buffer(8)]],        \
      const constant int* non_col_shapes [[buffer(9)]],               \
      const constant size_t* non_col_strides [[buffer(10)]],          \
      const constant int& non_col_ndim [[buffer(11)]],                \
      uint tid [[thread_position_in_grid]]);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Column reduce helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -171,9 +234,11 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Instantiate both the small and the general column-reduce kernels for one
// op/type pair (same input and output element type).
#define instantiate_same_col_reduce_helper(base, suffix, dtype, rop)  \
  instantiate_col_reduce_small(base ## suffix, dtype, dtype, rop<dtype>)   \
  instantiate_col_reduce_general(base ## suffix, dtype, dtype, rop<dtype>)

// Same as above, but pairs the small kernel with the no-atomics variant of
// the general kernel (used for types without atomic support, e.g. 64-bit ints).
#define instantiate_same_col_reduce_na_helper(base, suffix, dtype, rop)    \
  instantiate_col_reduce_small(base ## suffix, dtype, dtype, rop<dtype>)   \
  instantiate_col_reduce_general_no_atomics(base ## suffix, dtype, dtype, rop<dtype>)
|
||||
// Instantiate the (small + general) column-reduce kernel pair for every
// op/type combination enumerated by instantiate_reduce_helper_types.
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
@ -181,4 +246,8 @@ instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce
|
||||
|
||||
// Special-case instantiations for boolean inputs:
//  - sum of bools accumulates into uint32 (a bool accumulator would saturate),
//  - and/or reductions over each listed input type produce a bool result.
// General kernels:
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And)
// Fix: the `or` instantiation was emitted twice; a duplicate explicit
// instantiation with the same [[host_name]] is a redefinition, so keep one.
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)

// Small kernels (mirroring the general set above):
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or)
|
@ -307,13 +307,6 @@ void strided_reduce_general_dispatch(
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
Dtype out_dtype = out.dtype();
|
||||
bool is_out_64b_int = is_64b_int(out_dtype);
|
||||
auto kernel = (is_out_64b_int)
|
||||
? d.get_kernel(
|
||||
"col_reduce_general_no_atomics_" + op_name + type_to_name(in))
|
||||
: d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Prepare the arguments for the kernel
|
||||
size_t reduction_size = plan.shape.back();
|
||||
@ -327,6 +320,11 @@ void strided_reduce_general_dispatch(
|
||||
for (auto s : shape) {
|
||||
non_col_reductions *= static_cast<size_t>(s);
|
||||
}
|
||||
|
||||
std::vector<int> non_col_shapes = shape;
|
||||
std::vector<size_t> non_col_strides = strides;
|
||||
int non_col_ndim = shape.size();
|
||||
|
||||
auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
|
||||
for (auto s : rem_shape) {
|
||||
shape.push_back(s);
|
||||
@ -336,6 +334,54 @@ void strided_reduce_general_dispatch(
|
||||
}
|
||||
int ndim = shape.size();
|
||||
|
||||
// Specialize for small dims
|
||||
if (reduction_size * non_col_reductions < 16) {
|
||||
// Select kernel
|
||||
auto kernel =
|
||||
d.get_kernel("col_reduce_small_" + op_name + type_to_name(in));
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Select block dims
|
||||
MTL::Size grid_dims = MTL::Size(out_size, 1, 1);
|
||||
MTL::Size group_dims = MTL::Size(256ul, 1, 1);
|
||||
|
||||
if (non_col_ndim == 0) {
|
||||
non_col_shapes = {1};
|
||||
non_col_strides = {1};
|
||||
}
|
||||
|
||||
// Encode arrays
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
|
||||
compute_encoder->setBytes(
|
||||
strides.data(), strides.size() * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 7);
|
||||
compute_encoder->setBytes(&non_col_reductions, sizeof(size_t), 8);
|
||||
compute_encoder->setBytes(
|
||||
non_col_shapes.data(), non_col_shapes.size() * sizeof(int), 9);
|
||||
compute_encoder->setBytes(
|
||||
non_col_strides.data(), non_col_shapes.size() * sizeof(size_t), 10);
|
||||
compute_encoder->setBytes(&non_col_ndim, sizeof(int), 11);
|
||||
|
||||
// Dispatch threads
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Select kernel
|
||||
bool is_out_64b_int = is_64b_int(out_dtype);
|
||||
auto kernel = (is_out_64b_int)
|
||||
? d.get_kernel(
|
||||
"col_reduce_general_no_atomics_" + op_name + type_to_name(in))
|
||||
: d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Select block dimensions
|
||||
// Each thread reads 16 inputs to give it more work
|
||||
uint n_inputs_per_thread = REDUCE_N_READS;
|
||||
|
Loading…
Reference in New Issue
Block a user