diff --git a/mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal b/mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal
index 07b889052..d757ee6dd 100644
--- a/mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal
+++ b/mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal
@@ -6,6 +6,69 @@ using namespace metal;
 
+///////////////////////////////////////////////////////////////////////////////
+// Small column reduce kernel
+///////////////////////////////////////////////////////////////////////////////
+
+template <typename T, typename U, typename Op>
+[[kernel]] void col_reduce_small(
+    const device T *in [[buffer(0)]],
+    device U *out [[buffer(1)]],
+    const constant size_t& reduction_size [[buffer(2)]],
+    const constant size_t& reduction_stride [[buffer(3)]],
+    const constant size_t& out_size [[buffer(4)]],
+    const constant int* shape [[buffer(5)]],
+    const constant size_t* strides [[buffer(6)]],
+    const constant int& ndim [[buffer(7)]],
+    const constant size_t& non_col_reductions [[buffer(8)]],
+    const constant int* non_col_shapes [[buffer(9)]],
+    const constant size_t* non_col_strides [[buffer(10)]],
+    const constant int& non_col_ndim [[buffer(11)]],
+    uint tid [[thread_position_in_grid]]) {
+
+  // Appease the compiler
+  (void)out_size;
+
+  Op op;
+  U total_val = Op::init;
+
+  auto out_idx = tid;
+
+  in += elem_to_loc(
+      out_idx,
+      shape + non_col_ndim,
+      strides + non_col_ndim,
+      ndim - non_col_ndim);
+
+  for(uint i = 0; i < non_col_reductions; i++) {
+    size_t in_idx = elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
+
+    for(uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
+      U val = static_cast<U>(in[in_idx]);
+      total_val = op(total_val, val);
+    }
+  }
+
+  out[out_idx] = total_val;
+}
+
+#define instantiate_col_reduce_small(name, itype, otype, op) \
+  template [[host_name("col_reduce_small_" #name)]] \
+  [[kernel]] void col_reduce_small<itype, otype, op>( \
+      const device itype *in [[buffer(0)]], \
+      device otype *out [[buffer(1)]], \
+      const constant size_t& reduction_size [[buffer(2)]], \
+      const constant size_t& reduction_stride [[buffer(3)]], \
+      const constant size_t& out_size [[buffer(4)]], \
+      const constant int* shape [[buffer(5)]], \
+      const constant size_t* strides [[buffer(6)]], \
+      const constant int& ndim [[buffer(7)]], \
+      const constant size_t& non_col_reductions [[buffer(8)]], \
+      const constant int* non_col_shapes [[buffer(9)]], \
+      const constant size_t* non_col_strides [[buffer(10)]], \
+      const constant int& non_col_ndim [[buffer(11)]], \
+      uint tid [[thread_position_in_grid]]);
 
 ///////////////////////////////////////////////////////////////////////////////
 // Column reduce helper
 ///////////////////////////////////////////////////////////////////////////////
 
@@ -171,9 +234,11 @@ template
 ///////////////////////////////////////////////////////////////////////////////
 
 #define instantiate_same_col_reduce_helper(name, tname, type, op) \
+  instantiate_col_reduce_small(name ##tname, type, type, op) \
   instantiate_col_reduce_general(name ##tname, type, type, op)
 
 #define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
+  instantiate_col_reduce_small(name ##tname, type, type, op) \
   instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op)
 
 instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
@@ -181,4 +246,8 @@ instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce
 
 instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
 instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And)
-instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
\ No newline at end of file
+instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
+
+instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
+instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
+instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or)
\ No newline at end of file
diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp
index 8a19c602e..c0646017d 100644
--- a/mlx/backend/metal/reduce.cpp
+++ b/mlx/backend/metal/reduce.cpp
@@ -307,13 +307,6 @@ void strided_reduce_general_dispatch(
     metal::Device& d,
     const Stream& s) {
   Dtype out_dtype = out.dtype();
-  bool is_out_64b_int = is_64b_int(out_dtype);
-  auto kernel = (is_out_64b_int)
-      ? d.get_kernel(
-            "col_reduce_general_no_atomics_" + op_name + type_to_name(in))
-      : d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
-
-  compute_encoder->setComputePipelineState(kernel);
 
   // Prepare the arguments for the kernel
   size_t reduction_size = plan.shape.back();
@@ -327,6 +320,11 @@ void strided_reduce_general_dispatch(
   for (auto s : shape) {
     non_col_reductions *= static_cast<size_t>(s);
   }
+
+  std::vector<int> non_col_shapes = shape;
+  std::vector<size_t> non_col_strides = strides;
+  int non_col_ndim = shape.size();
+
   auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
   for (auto s : rem_shape) {
     shape.push_back(s);
@@ -336,6 +334,54 @@ void strided_reduce_general_dispatch(
   }
   int ndim = shape.size();
 
+  // Specialize for small dims
+  if (reduction_size * non_col_reductions < 16) {
+    // Select kernel
+    auto kernel =
+        d.get_kernel("col_reduce_small_" + op_name + type_to_name(in));
+    compute_encoder->setComputePipelineState(kernel);
+
+    // Select block dims
+    MTL::Size grid_dims = MTL::Size(out_size, 1, 1);
+    MTL::Size group_dims = MTL::Size(256ul, 1, 1);
+
+    if (non_col_ndim == 0) {
+      non_col_shapes = {1};
+      non_col_strides = {1};
+    }
+
+    // Encode arrays
+    set_array_buffer(compute_encoder, in, 0);
+    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
+    compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
+    compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
+    compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
+    compute_encoder->setBytes(
+        strides.data(), strides.size() * sizeof(size_t), 6);
+    compute_encoder->setBytes(&ndim, sizeof(int), 7);
+    compute_encoder->setBytes(&non_col_reductions, sizeof(size_t), 8);
+    compute_encoder->setBytes(
+        non_col_shapes.data(), non_col_shapes.size() * sizeof(int), 9);
+    compute_encoder->setBytes(
+        non_col_strides.data(), non_col_strides.size() * sizeof(size_t), 10);
+    compute_encoder->setBytes(&non_col_ndim, sizeof(int), 11);
+
+    // Dispatch threads
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+
+    return;
+  }
+
+  // Select kernel
+  bool is_out_64b_int = is_64b_int(out_dtype);
+  auto kernel = (is_out_64b_int)
+      ? d.get_kernel(
+            "col_reduce_general_no_atomics_" + op_name + type_to_name(in))
+      : d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
+
+  compute_encoder->setComputePipelineState(kernel);
+
   // Select block dimensions
   // Each thread reads 16 inputs to give it more work
   uint n_inputs_per_thread = REDUCE_N_READS;