diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py index c54af3a46..4adde50bc 100644 --- a/benchmarks/python/comparative/compare.py +++ b/benchmarks/python/comparative/compare.py @@ -125,6 +125,14 @@ if __name__ == "__main__": compare_filtered("sum_axis --size 16x128x1024 --axis 1") compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu") compare_filtered("sum_axis --size 16x128x1024 --axis 0") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --cpu") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,1") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --cpu") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,2") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1") compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu") compare_filtered("argmax --size 10x1024x128 --axis 1") compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu") diff --git a/mlx/backend/common/reduce.h b/mlx/backend/common/reduce.h index 740f54a48..da1d1658a 100644 --- a/mlx/backend/common/reduce.h +++ b/mlx/backend/common/reduce.h @@ -126,7 +126,7 @@ struct ReductionPlan { ReductionPlan get_reduction_plan(const array& x, const std::vector axes) { // The data is all there and we are reducing over everything if (x.size() == x.data_size() && axes.size() == x.ndim() && - (x.flags().row_contiguous || x.flags().col_contiguous)) { + x.flags().contiguous) { return ContiguousAllReduce; } diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal index 25bf1ee1f..85ff41f44 100644 --- a/mlx/backend/metal/kernels/reduce.metal +++ b/mlx/backend/metal/kernels/reduce.metal @@ -112,88 +112,33 @@ template uint simd_group_id [[simdgroup_index_in_threadgroup]]); -/////////////////////////////////////////////////////////////////////////////// -// General reduce -/////////////////////////////////////////////////////////////////////////////// - -template -[[kernel]] void general_reduce( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], - const device int *in_shape [[buffer(2)]], - const device size_t *in_strides [[buffer(3)]], - const device size_t *out_strides [[buffer(4)]], - const device size_t& ndim [[buffer(5)]], - uint gid [[thread_position_in_grid]]) { - Op op; - auto in_idx = elem_to_loc(gid, in_shape, in_strides, ndim); - auto out_idx = elem_to_loc(gid, in_shape, out_strides, ndim); - op.atomic_update(out, static_cast(in[in_idx]), out_idx); -} - -template -[[kernel]] void general_reduce( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], - const device int *in_shape [[buffer(2)]], - const device size_t *in_strides [[buffer(3)]], - const device size_t *out_strides [[buffer(4)]], - uint gid [[thread_position_in_grid]]) { - Op op; - auto in_idx = elem_to_loc_nd(gid, in_shape, in_strides); - auto out_idx = elem_to_loc_nd(gid, in_shape, out_strides); - op.atomic_update(out, static_cast(in[in_idx]), out_idx); -} - -#define instantiate_general_reduce_helper(name, itype, otype, op) \ - template [[host_name("general_reduce_" #name)]] \ - [[kernel]] void general_reduce( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const device int *in_shape [[buffer(2)]], 
\ - const device size_t *in_strides [[buffer(3)]], \ - const device size_t *out_strides [[buffer(4)]], \ - const device size_t& ndim [[buffer(5)]], \ - uint gid [[thread_position_in_grid]]); - -#define instantiate_general_reduce_helper_nd(name, itype, otype, op, n) \ - template [[host_name("general_reduce_" #name "_dim_" #n)]] \ - [[kernel]] void general_reduce( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const device int *in_shape [[buffer(2)]], \ - const device size_t *in_strides [[buffer(3)]], \ - const device size_t *out_strides [[buffer(4)]], \ - uint gid [[thread_position_in_grid]]); - -#define instantiate_general_reduce(name, itype, otype, op) \ - instantiate_general_reduce_helper(name, itype, otype, op) \ - instantiate_general_reduce_helper_nd(name, itype, otype, op, 1) \ - instantiate_general_reduce_helper_nd(name, itype, otype, op, 2) \ - instantiate_general_reduce_helper_nd(name, itype, otype, op, 3) \ - instantiate_general_reduce_helper_nd(name, itype, otype, op, 4) - - /////////////////////////////////////////////////////////////////////////////// // Row atomics /////////////////////////////////////////////////////////////////////////////// template -[[kernel]] void row_reduce( +[[kernel]] void row_reduce_general( const device T *in [[buffer(0)]], - device U *out [[buffer(1)]], - const device size_t& reduction_size [[buffer(2)]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], - uint tid [[threadgroup_position_in_grid]], + device mlx_atomic *out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& out_size [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_per_group [[simdgroups_per_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { Op op; - // Each threadgroup handles 1 reduction - in += tid * reduction_size + lid * N_READS; + // Each threadgroup handles 1 reduction + // TODO: Specializing elem_to_loc would be slightly faster + int idx = tid.y * out_size + tid.x; + int extra_offset = elem_to_loc(idx, shape, strides, ndim); + in += extra_offset + lid.x * N_READS; // The reduction is accumulated here U total_val = Op::init; @@ -201,7 +146,7 @@ template // Loop over the reduction size within thread group int r = 0; - for (; r < (int)ceildiv(reduction_size, N_READS*lsize) - 1; r++) { + for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) { T vals[N_READS]; for(int i = 0; i < N_READS; i++) { vals[i] = in[i]; @@ -210,11 +155,11 @@ template total_val = op(static_cast(vals[i]), total_val); } - in += lsize * N_READS; + in += lsize.x * N_READS; } - // Sepate case for the last set as we close the reduction size - size_t reduction_index = (lid + (size_t)lsize * r) * N_READS; + // Separate case for the last set as we close the reduction size + size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS; if(reduction_index < reduction_size) { int max_reads = reduction_size - reduction_index; @@ -240,26 +185,30 @@ template // Reduction within thread group // Only needed if multiple simd groups if(reduction_size > simd_size) { - total_val = lid < simd_per_group ? local_vals[lid] : op.init; + total_val = lid.x < simd_per_group ? 
local_vals[lid.x] : op.init; total_val = op.simd_reduce(total_val); } // Update output - if (lid == 0) { - out[tid] = total_val; + if (lid.x == 0) { + op.atomic_update(out, total_val, tid.x); } } -#define instantiate_row_reduce(name, itype, otype, op) \ - template [[host_name("row_reduce_" #name)]] \ - [[kernel]] void row_reduce( \ - const device itype *in [[buffer(0)]], \ - device otype *out [[buffer(1)]], \ - const device size_t& reduction_size [[buffer(2)]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint tid [[threadgroup_position_in_grid]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ +#define instantiate_row_reduce_general(name, itype, otype, op) \ + template [[host_name("row_reduce_general_" #name)]] \ + [[kernel]] void row_reduce_general( \ + const device itype *in [[buffer(0)]], \ + device mlx_atomic *out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant int* shape [[buffer(4)]], \ + const constant size_t* strides [[buffer(5)]], \ + const constant int& ndim [[buffer(6)]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]]); @@ -311,148 +260,57 @@ inline void _contiguous_strided_reduce( } template -[[kernel]] void col_reduce( +[[kernel]] void col_reduce_general( const device T *in [[buffer(0)]], device mlx_atomic *out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_stride [[buffer(3)]], const constant size_t& out_size [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], threadgroup U *local_data [[threadgroup(0)]], - uint2 tid [[threadgroup_position_in_grid]], - uint2 lid [[thread_position_in_threadgroup]], - uint2 lsize [[threads_per_threadgroup]]) { - auto out_idx = tid.x * lsize.x + lid.x; - + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]]) { + auto out_idx = tid.x * lsize.x + lid.x; + auto in_idx = elem_to_loc( + out_idx + tid.z * out_size, + shape, + strides, + ndim + ); + if(out_idx < out_size) { _contiguous_strided_reduce( in, out, local_data, - out_idx, + in_idx, out_idx, reduction_size, reduction_stride, - tid, - lid, - lsize); + tid.xy, + lid.xy, + lsize.xy); } } -#define instantiate_col_reduce(name, itype, otype, op) \ - template [[host_name("col_reduce_" #name)]] \ - [[kernel]] void col_reduce( \ +#define instantiate_col_reduce_general(name, itype, otype, op) \ + template [[host_name("col_reduce_general_" #name)]] \ + [[kernel]] void col_reduce_general( \ const device itype *in [[buffer(0)]], \ device mlx_atomic *out [[buffer(1)]], \ const constant size_t& reduction_size [[buffer(2)]], \ const constant size_t& reduction_stride [[buffer(3)]], \ const constant size_t& out_size [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ threadgroup otype *local_data [[threadgroup(0)]], \ - uint2 tid [[threadgroup_position_in_grid]], \ - uint2 lid [[thread_position_in_threadgroup]], \ - 
uint2 lsize [[threads_per_threadgroup]]); - -template -[[kernel]] void contiguous_strided_reduce( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& reduction_stride [[buffer(3)]], - const constant size_t& out_size [[buffer(4)]], - const device int* in_shape [[buffer(5)]], - const device size_t* in_strides [[buffer(6)]], - threadgroup U *local_data [[threadgroup(0)]], - uint2 tid [[threadgroup_position_in_grid]], - uint2 lid [[thread_position_in_threadgroup]], - uint2 lsize [[threads_per_threadgroup]]) { - - auto out_idx = tid.x * lsize.x + lid.x; - auto in_idx = elem_to_loc_nd(out_idx, in_shape, in_strides); - - if(out_idx < out_size) { - _contiguous_strided_reduce( - in, - out, - local_data, - in_idx, - out_idx, - reduction_size, - reduction_stride, - tid, - lid, - lsize); - } -} - -template -[[kernel]] void contiguous_strided_reduce( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& reduction_stride [[buffer(3)]], - const constant size_t& out_size [[buffer(4)]], - const device int* in_shape [[buffer(5)]], - const device size_t* in_strides [[buffer(6)]], - const device size_t& in_dim [[buffer(7)]], - threadgroup U *local_data [[threadgroup(0)]], - uint2 tid [[threadgroup_position_in_grid]], - uint2 lid [[thread_position_in_threadgroup]], - uint2 lsize [[threads_per_threadgroup]]) { - - auto out_idx = tid.x * lsize.x + lid.x; - auto in_idx = elem_to_loc(out_idx, in_shape, in_strides, in_dim); - - if(out_idx < out_size) { - _contiguous_strided_reduce( - in, - out, - local_data, - in_idx, - out_idx, - reduction_size, - reduction_stride, - tid, - lid, - lsize); - } -} - -#define instantiate_contiguous_strided_helper(name, itype, otype, op) \ - template [[host_name("contiguous_strided_reduce_" #name)]] \ - [[kernel]] void contiguous_strided_reduce( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const device int* in_shape [[buffer(5)]], \ - const device size_t* in_strides [[buffer(6)]], \ - const device size_t& in_dim [[buffer(7)]], \ - threadgroup otype *local_data [[threadgroup(0)]], \ - uint2 tid [[threadgroup_position_in_grid]], \ - uint2 lid [[thread_position_in_threadgroup]], \ - uint2 lsize [[threads_per_threadgroup]]); - -#define instantiate_contiguous_strided_helper_nd(name, itype, otype, op, n) \ - template [[host_name("contiguous_strided_reduce_" #name "_dim_" #n)]] \ - [[kernel]] void contiguous_strided_reduce( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const device int* in_shape [[buffer(5)]], \ - const device size_t* in_strides [[buffer(6)]], \ - threadgroup otype *local_data [[threadgroup(0)]], \ - uint2 tid [[threadgroup_position_in_grid]], \ - uint2 lid [[thread_position_in_threadgroup]], \ - uint2 lsize [[threads_per_threadgroup]]); - -#define instantiate_contiguous_strided(name, itype, otype, op) \ - instantiate_contiguous_strided_helper(name, itype, otype, op) \ - instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 1) \ - 
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 2) \ - instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 3) \ - instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 4) + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]]); /////////////////////////////////////////////////////////////////////////////// @@ -461,10 +319,8 @@ template #define instantiate_reduce(name, itype, otype, op) \ instantiate_all_reduce(name, itype, otype, op) \ - instantiate_row_reduce(name, itype, otype, op) \ - instantiate_col_reduce(name, itype, otype, op) \ - instantiate_contiguous_strided(name, itype, otype, op) \ - instantiate_general_reduce(name, itype, otype, op) + instantiate_row_reduce_general(name, itype, otype, op) \ + instantiate_col_reduce_general(name, itype, otype, op) #define instantiate_same_reduce(name, tname, type, op) \ instantiate_init_reduce(name ##tname, type, op) \ @@ -535,4 +391,4 @@ instantiate_same_reduce(max_, float16, half, Max) instantiate_same_reduce(max_, float32, float, Max) instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min) -instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max) \ No newline at end of file +instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max) diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 532f18353..6a2ce084b 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -2,9 +2,11 @@ #include #include +#include #include #include "mlx/backend/common/reduce.h" +#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/utils.h" @@ -61,22 +63,47 @@ void all_reduce_dispatch( compute_encoder->dispatchThreads(grid_dims, group_dims); } -void row_reduce_dispatch( +void row_reduce_general_dispatch( const array& in, array& out, const std::string& op_name, - const std::vector& axes_, + const ReductionPlan& plan, + const std::vector& axes, MTL::ComputeCommandEncoder* compute_encoder, metal::Device& d) { - auto kernel = d.get_kernel("row_reduce_" + op_name + type_to_name(in)); + auto kernel = + d.get_kernel("row_reduce_general_" + op_name + type_to_name(in)); + // Prepare the arguments for the kernel int n_reads = REDUCE_N_READS; - size_t reduction_size = in.size() / out.size(); + size_t reduction_size = plan.shape.back(); + size_t out_size = out.size(); + auto shape = plan.shape; + auto strides = plan.strides; + shape.pop_back(); + strides.pop_back(); + size_t non_row_reductions = 1; + for (auto s : shape) { + non_row_reductions *= static_cast(s); + } + auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes); + for (auto s : rem_shape) { + shape.push_back(s); + } + for (auto s : rem_strides) { + strides.push_back(s); + } + int ndim = shape.size(); + // Set the arguments for the kernel compute_encoder->setComputePipelineState(kernel); set_array_buffer(compute_encoder, in, 0); set_array_buffer(compute_encoder, out, 1); compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); + compute_encoder->setBytes(&out_size, sizeof(size_t), 3); + compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); + compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5); + compute_encoder->setBytes(&ndim, sizeof(int), 6); // Each thread group is responsible for 1 output NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); @@ -91,92 +118,54 @@ void 
row_reduce_dispatch( // Launch enough thread groups for each output size_t n_threads = out.size() * thread_group_size; - MTL::Size grid_dims = MTL::Size(n_threads, 1, 1); + MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); compute_encoder->dispatchThreads(grid_dims, group_dims); } -void col_reduce_dispatch( +void strided_reduce_general_dispatch( const array& in, array& out, const std::string& op_name, - const std::vector& axes_, + const ReductionPlan& plan, + const std::vector& axes, MTL::ComputeCommandEncoder* compute_encoder, metal::Device& d) { - std::ostringstream kernel_name; + auto kernel = + d.get_kernel("col_reduce_general_" + op_name + type_to_name(in)); - bool encode_in_shape = false; - bool encode_ndim = false; - - // If the slowest moving axis can be merged into the reductions, - // we call the column reduce kernel - // In this case, a linear index in the output corresponds to the - // linear index in the input where the reduction starts - if (axes_[axes_.size() - 1] == (axes_.size() - 1)) { - kernel_name << "col_reduce_" << op_name << type_to_name(in); - } - // Otherwise, while all the reduction axes can be merged, the mapping between - // indices in the output and input require resolving using shapes and strides - else { - kernel_name << "contiguous_strided_reduce_" << op_name << type_to_name(in); - encode_in_shape = true; - - // We check for a viable template with the required number of dimensions - // we only care about encoding non-reduced shapes and strides in the input - size_t non_reducing_dims = in.ndim() - axes_.size(); - if (non_reducing_dims >= 1 && - non_reducing_dims <= MAX_REDUCE_SPECIALIZED_DIMS) { - kernel_name << "_dim_" << non_reducing_dims; - } else { - encode_ndim = true; - } - } - - auto kernel = d.get_kernel(kernel_name.str()); - size_t in_size = in.size(); + // Prepare the arguments for the kernel + size_t reduction_size = plan.shape.back(); + size_t reduction_stride = plan.strides.back(); size_t out_size = out.size(); + auto shape = plan.shape; + auto strides = plan.strides; + shape.pop_back(); + strides.pop_back(); + size_t non_col_reductions = 1; + for (auto s : shape) { + non_col_reductions *= static_cast(s); + } + auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes); + for (auto s : rem_shape) { + shape.push_back(s); + } + for (auto s : rem_strides) { + strides.push_back(s); + } + int ndim = shape.size(); + // Set the arguments for the kernel compute_encoder->setComputePipelineState(kernel); set_array_buffer(compute_encoder, in, 0); set_array_buffer(compute_encoder, out, 1); - - // Calculate the number of inputs to reduce and the stride b/w them - size_t reduction_size = 1; - size_t in_ndim = in.ndim(); - size_t reduction_stride = in_size; - - for (int i : axes_) { - reduction_size *= in.shape(i); - reduction_stride = std::min(reduction_stride, in.strides()[i]); - } - compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); compute_encoder->setBytes(&out_size, sizeof(size_t), 4); - if (encode_in_shape) { - // Obtain the non-reducing shape and strides of the input to encode - std::vector inp_shape_mod; - std::vector inp_strides_mod; - - for (size_t i = 0, j = 0; i < in.ndim(); i++) { - if (j < axes_.size() && axes_[j] == i) { - j++; - } else { - inp_shape_mod.push_back(in.shape(i)); - inp_strides_mod.push_back(in.strides()[i]); - } - } - - size_t ndim = inp_shape_mod.size(); - - 
compute_encoder->setBytes(inp_shape_mod.data(), ndim * sizeof(int), 5); - compute_encoder->setBytes(inp_strides_mod.data(), ndim * sizeof(size_t), 6); - - if (encode_ndim) { - compute_encoder->setBytes(&ndim, sizeof(size_t), 7); - } - } + compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); + compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6); + compute_encoder->setBytes(&ndim, sizeof(int), 7); // Select block dimensions @@ -200,7 +189,8 @@ void col_reduce_dispatch( (n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y; // Launch enough thread groups for each output - MTL::Size grid_dims = MTL::Size(n_threadgroups_x, n_threadgroups_y, 1); + MTL::Size grid_dims = + MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions); MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1); // We set shared memory to be exploited here for reductions within a @@ -216,60 +206,6 @@ void col_reduce_dispatch( compute_encoder->dispatchThreadgroups(grid_dims, group_dims); } -void general_reduce_dispatch( - const array& in, - array& out, - const std::string& op_name, - const std::vector& axes_, - MTL::ComputeCommandEncoder* compute_encoder, - metal::Device& d) { - bool encode_ndim = true; - std::ostringstream kernel_name; - kernel_name << "general_reduce_" << op_name << type_to_name(in); - - // Check for specialzed kernels for input ndim - if (in.ndim() >= 1 && in.ndim() <= MAX_REDUCE_SPECIALIZED_DIMS) { - kernel_name << "_dim_" << in.ndim(); - encode_ndim = false; - } - auto kernel = d.get_kernel(kernel_name.str()); - size_t in_size = in.size(); - size_t ndim = in.ndim(); - - // We set the reducing strides to 0 to induce collisions for the reduction - std::vector out_strides(ndim); - size_t stride = 1; - for (int i = ndim - 1, j = axes_.size() - 1; i >= 0; --i) { - if (j >= 0 && axes_[j] == i) { - out_strides[i] = 0; - --j; - } else { - out_strides[i] = stride; - stride *= in.shape(i); - } - } - - compute_encoder->setComputePipelineState(kernel); - set_array_buffer(compute_encoder, in, 0); - set_array_buffer(compute_encoder, out, 1); - compute_encoder->setBytes(in.shape().data(), ndim * sizeof(int), 2); - compute_encoder->setBytes(in.strides().data(), ndim * sizeof(size_t), 3); - compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4); - if (encode_ndim) { - compute_encoder->setBytes(&ndim, sizeof(size_t), 5); - } - - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size > in_size) { - thread_group_size = in_size; - } - size_t nthreads = in_size; - - MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); - compute_encoder->dispatchThreads(grid_dims, group_dims); -} - } // namespace ////////////////////////////////////////////////////////////////////// @@ -278,7 +214,7 @@ void general_reduce_dispatch( void Reduce::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); - auto& in = inputs[0]; + array in = inputs[0]; // TODO: Allow specific row and column reductions with types disabled // due to atomics ? 
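Note on the index mapping used above: both row_reduce_general and col_reduce_general resolve their input offsets through elem_to_loc, which walks the encoded shape/strides to turn a flattened index into a memory offset. Below is a minimal sketch of that mapping, assuming the usual row-major linear-index decomposition; it is illustrative only and not the exact in-tree Metal helper.

// Sketch only: map a flattened element index to a memory offset via
// shape/strides (row-major decomposition). The real helper may differ
// in address-space qualifiers and early-exit details.
inline size_t elem_to_loc_sketch(
    size_t elem, const int* shape, const size_t* strides, int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

This indirection is what lets the general row/column kernels index transposed or otherwise non-contiguous non-reduced axes directly, so only the GeneralReduce case (handled in the hunk below) needs a preliminary copy.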
@@ -335,36 +271,46 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { // Reduce { - // Check for contiguous data - if (in.size() == in.data_size() && - (in.flags().row_contiguous || in.flags().col_contiguous)) { - // Go to all reduce if reducing over all axes - if (axes_.size() == in.ndim()) { - all_reduce_dispatch(in, out, op_name, compute_encoder, d); - return; - } - // Use specialized kernels if the input is row contiguous and - // the reducing axes can be merged into one - else if ( - in.flags().row_contiguous && in.strides().back() == 1 && - (axes_.back() - axes_.front()) == axes_.size() - 1) { - // If the fastest moving axis is being reduced, go to row reduce - if (axes_[0] == (in.ndim() - axes_.size())) { - row_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d); - return; - } - // Otherwise go to to generalized strided reduce - // Note: bool isn't support here yet due to the use of atomics - // once that is updated, this should be the else condition of this - // branch - else if (in.dtype() != bool_) { - col_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d); - return; - } - } + std::vector copies; + ReductionPlan plan = get_reduction_plan(in, axes_); + + // If it is a general reduce then copy the input to a contiguous array and + // recompute the plan. + if (plan.type == GeneralReduce) { + array in_copy(in.shape(), in.dtype(), nullptr, {}); + copy_gpu(in, in_copy, CopyType::General, s); + copies.push_back(in_copy); + in = in_copy; + plan = get_reduction_plan(in, axes_); + } + + // Reducing over everything and the data is all there no broadcasting or + // slicing etc. + if (plan.type == ContiguousAllReduce) { + all_reduce_dispatch(in, out, op_name, compute_encoder, d); + } + + // At least the last dimension is row contiguous and we are reducing over + // the last dim. + else if ( + plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) { + row_reduce_general_dispatch( + in, out, op_name, plan, axes_, compute_encoder, d); + } + + // At least the last two dimensions are contiguous and we are doing a + // strided reduce over these. + else if ( + plan.type == ContiguousStridedReduce || + plan.type == GeneralStridedReduce) { + strided_reduce_general_dispatch( + in, out, op_name, plan, axes_, compute_encoder, d); + } + + if (!copies.empty()) { + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); } - // Fall back to the general case - general_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d); } }
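For reference, the new compare.py cases at the top of this diff exercise exactly these paths: a multi-axis reduction over a (possibly transposed) 16x128x1024 array used to hit the removed general_reduce fallback and now goes through the general row/column kernels. The sketch below is a minimal end-to-end example written against the C++ API, under the assumption that the usual mlx.h entry points (random::uniform, transpose, sum, eval) are available; the benchmarks drive roughly the same operation through the Python bindings.

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Same extents as the benchmark entries: --size 16x128x1024
  auto x = random::uniform({16, 128, 1024});

  // --transpose 0,2,1 makes the reduced axes non-contiguous in memory,
  // which is the case the general row/col reduce kernels now cover.
  auto xt = transpose(x, {0, 2, 1});

  // --axis 0,2: a multi-axis reduction that previously fell back to the
  // (now removed) general_reduce kernel.
  auto y = sum(xt, {0, 2});

  eval({y});
  return 0;
}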