From fcc5ac1c64256f9de0d562e8234eb0c5c163fdee Mon Sep 17 00:00:00 2001 From: Vijay Krish Date: Wed, 31 Jan 2024 11:18:04 -0800 Subject: [PATCH] Add GPU support for uint64/int64 reductions (#569) --- mlx/backend/metal/kernels/arg_reduce.metal | 12 - mlx/backend/metal/kernels/reduce.metal | 459 +++++++++++++++------ mlx/backend/metal/kernels/utils.h | 18 + mlx/backend/metal/reduce.cpp | 355 +++++++++++++--- python/tests/test_reduce.py | 2 + 5 files changed, 654 insertions(+), 192 deletions(-) diff --git a/mlx/backend/metal/kernels/arg_reduce.metal b/mlx/backend/metal/kernels/arg_reduce.metal index 467e768d6..f153b920d 100644 --- a/mlx/backend/metal/kernels/arg_reduce.metal +++ b/mlx/backend/metal/kernels/arg_reduce.metal @@ -63,18 +63,6 @@ struct ArgMax { } }; -bool simd_shuffle_down(bool data, uint16_t delta) { - return simd_shuffle_down(static_cast(data), delta); -} - -uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) { - return as_type(simd_shuffle_down(as_type(data), delta)); -} - -int64_t simd_shuffle_down(int64_t data, uint16_t delta) { - return as_type(simd_shuffle_down(as_type(data), delta)); -} - template IndexValPair simd_shuffle_down(IndexValPair data, uint16_t delta) { return IndexValPair( diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal index 4182184c2..ee00f48ff 100644 --- a/mlx/backend/metal/kernels/reduce.metal +++ b/mlx/backend/metal/kernels/reduce.metal @@ -24,11 +24,59 @@ template device otype *out [[buffer(1)]], \ uint tid [[thread_position_in_grid]]); - /////////////////////////////////////////////////////////////////////////////// // All reduce /////////////////////////////////////////////////////////////////////////////// +template +inline U per_thread_all_reduce( + const device T *in, + const device size_t& in_size, + uint gid, + uint grid_size) { + Op op; + U total_val = Op::init; + + if (gid * N_READS < in_size) { + in += gid * N_READS; + + int r = 0; + for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) { + U vals[N_READS] = {op.init}; + + for(int i = 0; i < N_READS; i++) { + vals[i] = static_cast(in[i]); + } + for(int i = 0; i < N_READS; i++) { + total_val = op(vals[i], total_val); + } + + in += grid_size * N_READS; + } + + // Separate case for the last set as we close the reduction size + size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS; + if (curr_idx < in_size) { + int max_reads = in_size - curr_idx; + T vals[N_READS]; + + for(int i = 0, idx = 0; i < N_READS; i++, idx++) { + idx = idx < max_reads ? idx : max_reads - 1; + vals[i] = in[idx]; + } + for(int i = 0; i < N_READS; i++) { + U val = i < max_reads ? vals[i] : Op::init; + total_val = op(static_cast(val), total_val); + } + } + } + + return total_val; +} + +// NB: This kernel assumes threads_per_threadgroup is at most +// 1024. This way with a simd_size of 32, we are guaranteed to +// complete the reduction in two steps of simd-level reductions. template [[kernel]] void all_reduce( const device T *in [[buffer(0)]], @@ -40,53 +88,18 @@ template uint simd_per_group [[simdgroups_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - // NB: this kernel assumes threads_per_threadgroup is at most - // 1024. This way with a simd_size of 32, we are guaranteed to - // complete the reduction in two steps of simd-level reductions. 
Op op; threadgroup U local_vals[simd_size]; - - U total_val = Op::init; - in += gid * N_READS; - - int r = 0; - for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) { - U vals[N_READS] = {op.init}; - - for(int i = 0; i < N_READS; i++) { - vals[i] = static_cast(in[i]); - } - for(int i = 0; i < N_READS; i++) { - total_val = op(vals[i], total_val); - } - - in += grid_size * N_READS; - } - - // Separate case for the last set as we close the reduction size - size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS; - if (curr_idx < in_size) { - int max_reads = in_size - curr_idx; - T vals[N_READS]; - - for(int i = 0, idx = 0; i < N_READS; i++, idx++) { - idx = idx < max_reads ? idx : max_reads - 1; - vals[i] = in[idx]; - } - for(int i = 0; i < N_READS; i++) { - U val = i < max_reads ? vals[i] : Op::init; - total_val = op(static_cast(val), total_val); - } - } + U total_val = per_thread_all_reduce(in, in_size, gid, grid_size); // Reduction within simd group total_val = op.simd_reduce(total_val); if (simd_lane_id == 0) { local_vals[simd_group_id] = total_val; } - + // Reduction within thread group threadgroup_barrier(mem_flags::mem_threadgroup); total_val = lid < simd_per_group ? local_vals[lid] : op.init; @@ -98,6 +111,46 @@ template } } +template +[[kernel]] void all_reduce_no_atomics( + const device T *in [[buffer(0)]], + device U *out [[buffer(1)]], + const device size_t& in_size [[buffer(2)]], + uint gid [[thread_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint grid_size [[threads_per_grid]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint thread_group_id [[threadgroup_position_in_grid]]) { + + Op op; + threadgroup U local_vals[simd_size]; + + U total_val = per_thread_all_reduce(in, in_size, gid, grid_size); + + // Reduction within simd group (simd_add isn't supported for uint64/int64 types) + for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) { + total_val = op(total_val, simd_shuffle_down(total_val, lane_offset)); + } + // Write simd group reduction results to local memory + if (simd_lane_id == 0) { + local_vals[simd_group_id] = total_val; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Reduction of simdgroup reduction results within threadgroup. + total_val = lid < simd_per_group ? 
local_vals[lid] : op.init; + for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) { + total_val = op(total_val, simd_shuffle_down(total_val, lane_offset)); + } + + // Reduction across threadgroups + if (lid == 0) { + out[thread_group_id] = total_val; + } +} + #define instantiate_all_reduce(name, itype, otype, op) \ template [[host_name("all_reduce_" #name)]] \ [[kernel]] void all_reduce( \ @@ -111,11 +164,80 @@ template uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]]); +#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \ + template [[host_name("all_reduce_no_atomics_" #name)]] \ + [[kernel]] void all_reduce_no_atomics( \ + const device itype *in [[buffer(0)]], \ + device otype *out [[buffer(1)]], \ + const device size_t& in_size [[buffer(2)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint grid_size [[threads_per_grid]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint thread_group_id [[threadgroup_position_in_grid]]); /////////////////////////////////////////////////////////////////////////////// // Row atomics /////////////////////////////////////////////////////////////////////////////// +template +inline U per_thread_row_reduce( + const device T *in, + const constant size_t& reduction_size, + const constant size_t& out_size, + const constant int* shape, + const constant size_t* strides, + const constant int& ndim, + uint lsize_x, + uint lid_x, + uint2 tid) { + + Op op; + + // Each threadgroup handles 1 reduction + // TODO: Specializing elem_to_loc would be slightly faster + int idx = tid.y * out_size + tid.x; + int extra_offset = elem_to_loc(idx, shape, strides, ndim); + in += extra_offset + lid_x * N_READS; + + // The reduction is accumulated here + U total_val = Op::init; + + // Loop over the reduction size within thread group + int r = 0; + for (; r < (int)ceildiv(reduction_size, N_READS*lsize_x) - 1; r++) { + T vals[N_READS]; + for(int i = 0; i < N_READS; i++) { + vals[i] = in[i]; + } + for(int i = 0; i < N_READS; i++) { + total_val = op(static_cast(vals[i]), total_val); + } + + in += lsize_x * N_READS; + } + + // Separate case for the last set as we close the reduction size + size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS; + if(reduction_index < reduction_size) { + int max_reads = reduction_size - reduction_index; + + T vals[N_READS]; + for(int i = 0; i < N_READS; i++) { + int idx = min(i, max_reads - 1); + vals[i] = static_cast(in[idx]); + } + for(int i = 0; i < N_READS; i++) { + T val = i < max_reads ? 
vals[i] : Op::init; + total_val = op(static_cast(val), total_val); + } + } + + return total_val; +} + template [[kernel]] void row_reduce_general( const device T *in [[buffer(0)]], @@ -133,46 +255,9 @@ template uint simd_group_id [[simdgroup_index_in_threadgroup]]) { Op op; - - // Each threadgroup handles 1 reduction - // TODO: Specializing elem_to_loc would be slightly faster - int idx = tid.y * out_size + tid.x; - int extra_offset = elem_to_loc(idx, shape, strides, ndim); - in += extra_offset + lid.x * N_READS; - - // The reduction is accumulated here - U total_val = Op::init; threadgroup U local_vals[simd_size]; - // Loop over the reduction size within thread group - int r = 0; - for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) { - T vals[N_READS]; - for(int i = 0; i < N_READS; i++) { - vals[i] = in[i]; - } - for(int i = 0; i < N_READS; i++) { - total_val = op(static_cast(vals[i]), total_val); - } - - in += lsize.x * N_READS; - } - - // Separate case for the last set as we close the reduction size - size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS; - if(reduction_index < reduction_size) { - int max_reads = reduction_size - reduction_index; - - T vals[N_READS]; - for(int i = 0; i < N_READS; i++) { - int idx = min(i, max_reads - 1); - vals[i] = static_cast(in[idx]); - } - for(int i = 0; i < N_READS; i++) { - T val = i < max_reads ? vals[i] : Op::init; - total_val = op(static_cast(val), total_val); - } - } + U total_val = per_thread_row_reduce(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy); total_val = op.simd_reduce(total_val); @@ -194,6 +279,53 @@ template } } +template +[[kernel]] void row_reduce_general_no_atomics( + const device T *in [[buffer(0)]], + device U *out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& out_size [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint3 gsize [[threads_per_grid]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + + Op op; + + threadgroup U local_vals[simd_size]; + U total_val = per_thread_row_reduce(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy); + + // Reduction within simd group - simd_add isn't supported for int64 types + for (uint16_t i = simd_size/2; i > 0; i /= 2) { + total_val = op(total_val, simd_shuffle_down(total_val, i)); + } + + // Prepare next level + if (simd_lane_id == 0) { + local_vals[simd_group_id] = total_val; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Reduction within thread group + // Only needed if thread group has multiple simd groups + if(ceildiv(reduction_size, N_READS) > simd_size) { + total_val = lid.x < simd_per_group ? 
local_vals[lid.x] : op.init; + for (uint16_t i = simd_size/2; i > 0; i /= 2) { + total_val = op(total_val, simd_shuffle_down(total_val, i)); + } + } + // Write row reduce output for threadgroup with 1st thread in thread group + if (lid.x == 0) { + out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val; + } +} + #define instantiate_row_reduce_general(name, itype, otype, op) \ template [[host_name("row_reduce_general_" #name)]] \ [[kernel]] void row_reduce_general( \ @@ -211,52 +343,59 @@ template uint simd_per_group [[simdgroups_per_threadgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]]); +#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \ + template [[host_name("row_reduce_general_no_atomics_" #name)]] \ + [[kernel]] void row_reduce_general_no_atomics( \ + const device itype *in [[buffer(0)]], \ + device otype *out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant int* shape [[buffer(4)]], \ + const constant size_t* strides [[buffer(5)]], \ + const constant int& ndim [[buffer(6)]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint3 gsize [[threads_per_grid]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); /////////////////////////////////////////////////////////////////////////////// // Column reduce /////////////////////////////////////////////////////////////////////////////// template -inline void _contiguous_strided_reduce( - const device T *in, - device mlx_atomic *out, - threadgroup U *local_data, - uint in_idx, - uint out_idx, - uint reduction_size, - uint reduction_stride, - uint2 tid, - uint2 lid, +inline U _contiguous_strided_reduce( + const device T *in, + threadgroup U *local_data, + uint in_idx, + uint reduction_size, + uint reduction_stride, + uint2 tid, + uint2 lid, uint2 lsize) { Op op; - T local_vals[N_READS]; + U total_val = Op::init; uint base_offset = (tid.y * lsize.y + lid.y) * N_READS; - - for(uint r = 0; r < N_READS; r++) { - uint offset = base_offset + r; - offset = offset < reduction_size ? 
offset : reduction_size - 1; - local_vals[r] = in[in_idx + offset * reduction_stride]; - } - - U total_val = Op::init; for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) { - total_val = op(static_cast(total_val), local_vals[r]); + uint offset = base_offset + r; + total_val = op(static_cast(total_val), in[in_idx + offset * reduction_stride]); } - local_data[lsize.y * lid.x + lid.y] = total_val; - + local_data[lsize.y * lid.x + lid.y] = total_val; threadgroup_barrier(mem_flags::mem_threadgroup); + U val = Op::init; if(lid.y == 0) { - U val = op.init; - + // Perform reduction across columns in thread group for(uint i = 0; i < lsize.y; i++) { - val = op(val, local_data[lsize.y * lid.x + i]); + val = op(val, local_data[lsize.y * lid.x + i]); } - - op.atomic_update(out, val, out_idx); } + + return val; } template @@ -265,13 +404,13 @@ template device mlx_atomic *out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_stride [[buffer(3)]], - const constant size_t& out_size [[buffer(4)]], + const constant size_t& out_size [[buffer(4)]], const constant int* shape [[buffer(5)]], const constant size_t* strides [[buffer(6)]], const constant int& ndim [[buffer(7)]], threadgroup U *local_data [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], uint3 lsize [[threads_per_threadgroup]]) { auto out_idx = tid.x * lsize.x + lid.x; auto in_idx = elem_to_loc( @@ -281,18 +420,66 @@ template ndim ); + Op op; if(out_idx < out_size) { - _contiguous_strided_reduce( - in, - out, - local_data, - in_idx, - out_idx, - reduction_size, - reduction_stride, - tid.xy, - lid.xy, - lsize.xy); + U val = _contiguous_strided_reduce( + in, + local_data, + in_idx, + reduction_size, + reduction_stride, + tid.xy, + lid.xy, + lsize.xy); + + // Write out reduction results generated by threadgroups working on specific output element, contiguously. + if (lid.y == 0) { + op.atomic_update(out, val, out_idx); + } + } +} + +template +[[kernel]] void col_reduce_general_no_atomics( + const device T *in [[buffer(0)]], + device U *out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant size_t& out_size [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + threadgroup U *local_data [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 gid [[thread_position_in_grid]], + uint3 lsize [[threads_per_threadgroup]], + uint3 gsize [[threads_per_grid]]) { + auto out_idx = tid.x * lsize.x + lid.x; + auto in_idx = elem_to_loc( + out_idx + tid.z * out_size, + shape, + strides, + ndim + ); + + if(out_idx < out_size) { + U val = _contiguous_strided_reduce( + in, + local_data, + in_idx, + reduction_size, + reduction_stride, + tid.xy, + lid.xy, + lsize.xy); + + // Write out reduction results generated by threadgroups working on specific output element, contiguously. 
+ if (lid.y == 0) { + uint tgsize_y = ceildiv(gsize.y, lsize.y); + uint tgsize_z = ceildiv(gsize.z, lsize.z); + out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val; + } } } @@ -312,6 +499,23 @@ template uint3 lid [[thread_position_in_threadgroup]], \ uint3 lsize [[threads_per_threadgroup]]); +#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \ + template [[host_name("col_reduce_general_no_atomics_" #name)]] \ + [[kernel]] void col_reduce_general_no_atomics( \ + const device itype *in [[buffer(0)]], \ + device otype *out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& reduction_stride [[buffer(3)]], \ + const constant size_t& out_size [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + threadgroup otype *local_data [[threadgroup(0)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 gid [[thread_position_in_grid]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint3 gsize [[threads_per_grid]]); /////////////////////////////////////////////////////////////////////////////// // Instantiations @@ -322,6 +526,15 @@ template instantiate_row_reduce_general(name, itype, otype, op) \ instantiate_col_reduce_general(name, itype, otype, op) +#define instantiate_reduce_no_atomics(name, itype, otype, op) \ + instantiate_all_reduce_no_atomics(name, itype, otype, op) \ + instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \ + instantiate_col_reduce_general_no_atomics(name, itype, otype, op) + +#define instantiate_same_reduce_no_atomics(name, tname, type, op) \ + instantiate_init_reduce(name ##tname, type, op) \ + instantiate_reduce_no_atomics(name ##tname, type, type, op) + #define instantiate_same_reduce(name, tname, type, op) \ instantiate_init_reduce(name ##tname, type, op) \ instantiate_reduce(name ##tname, type, type, op) @@ -353,6 +566,9 @@ instantiate_same_reduce(sum, int32, int32_t, Sum) instantiate_same_reduce(sum, float16, half, Sum) instantiate_same_reduce(sum, float32, float, Sum) +instantiate_same_reduce_no_atomics(sum, int64, int64_t, Sum) +instantiate_same_reduce_no_atomics(sum, uint64, uint64_t, Sum) + instantiate_same_reduce(prod, uint8, uint8_t, Prod) instantiate_same_reduce(prod, uint16, uint16_t, Prod) instantiate_same_reduce(prod, uint32, uint32_t, Prod) @@ -362,6 +578,9 @@ instantiate_same_reduce(prod, int32, int32_t, Prod) instantiate_same_reduce(prod, float16, half, Prod) instantiate_same_reduce(prod, float32, float, Prod) +instantiate_same_reduce_no_atomics(prod, int64, int64_t, Prod) +instantiate_same_reduce_no_atomics(prod, uint64, uint64_t, Prod) + instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum) instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod) @@ -381,6 +600,9 @@ instantiate_same_reduce(min_, int32, int32_t, Min) instantiate_same_reduce(min_, float16, half, Min) instantiate_same_reduce(min_, float32, float, Min) +instantiate_same_reduce_no_atomics(min_, int64, int64_t, Min) +instantiate_same_reduce_no_atomics(min_, uint64, uint64_t, Min) + instantiate_same_reduce(max_, uint8, uint8_t, Max) instantiate_same_reduce(max_, uint16, uint16_t, Max) instantiate_same_reduce(max_, uint32, uint32_t, Max) @@ -390,5 +612,8 @@ instantiate_same_reduce(max_, int32, int32_t, Max) instantiate_same_reduce(max_, float16, half, Max) instantiate_same_reduce(max_, float32, float, Max) 
+instantiate_same_reduce_no_atomics(max_, int64, int64_t, Max) +instantiate_same_reduce_no_atomics(max_, uint64, uint64_t, Max) + instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min) instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max) diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index ce6dc2954..634a9d6df 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -256,3 +256,21 @@ inline bfloat16_t log1p(bfloat16_t x) { return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); } + +/////////////////////////////////////////////////////////////////////////////// +// SIMD shuffle ops +/////////////////////////////////////////////////////////////////////////////// + +inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) { + return as_type( + metal::simd_shuffle_down(as_type(data), delta)); +} + +inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) { + return as_type( + metal::simd_shuffle_down(as_type(data), delta)); +} + +inline bool simd_shuffle_down(bool data, uint16_t delta) { + return simd_shuffle_down(static_cast(data), delta); +} \ No newline at end of file diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 601ae4b61..9b7e729da 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -28,35 +28,40 @@ inline auto safe_divup(size_t n, size_t m) { return safe_div(n, m) * m; } +inline bool is_64b_int(Dtype dtype) { + return dtype == int64 || dtype == uint64; +} + // All Reduce void all_reduce_dispatch( const array& in, array& out, const std::string& op_name, MTL::ComputeCommandEncoder* compute_encoder, - metal::Device& d) { - // Get kernel and encode buffers - size_t in_size = in.size(); - auto kernel = d.get_kernel("all_reduce_" + op_name + type_to_name(in)); + metal::Device& d, + const Stream& s) { + Dtype out_dtype = out.dtype(); + bool is_out_64b_int = is_64b_int(out_dtype); + auto kernel = (is_out_64b_int) + ? d.get_kernel("all_reduce_no_atomics_" + op_name + type_to_name(in)) + : d.get_kernel("all_reduce_" + op_name + type_to_name(in)); compute_encoder->setComputePipelineState(kernel); - set_array_buffer(compute_encoder, in, 0); - set_array_buffer(compute_encoder, out, 1); - compute_encoder->setBytes(&in_size, sizeof(size_t), 2); - - // Set grid dimensions // We make sure each thread has enough to do by making it read in // at least n_reads inputs int n_reads = REDUCE_N_READS; + size_t in_size = in.size(); // mod_in_size gives us the groups of n_reads needed to go over the entire // input uint mod_in_size = (in_size + n_reads - 1) / n_reads; - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); thread_group_size = mod_in_size > thread_group_size ? 
thread_group_size : mod_in_size; + uint simd_size = kernel->threadExecutionWidth(); + thread_group_size = + ((thread_group_size + simd_size - 1) / simd_size) * simd_size; // If the number of thread groups needed exceeds 1024, we reuse threads groups uint n_thread_groups = safe_div(mod_in_size, thread_group_size); @@ -66,7 +71,52 @@ void all_reduce_dispatch( MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); - compute_encoder->dispatchThreads(grid_dims, group_dims); + // Encode buffers and dispatch + if (is_out_64b_int == false || n_thread_groups == 1) { + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(&in_size, sizeof(size_t), 2); + compute_encoder->dispatchThreads(grid_dims, group_dims); + + } else { + // Allocate intermediate array to store partial reduction results + size_t intermediate_size = n_thread_groups; + array intermediate = + array({static_cast(intermediate_size)}, out_dtype, nullptr, {}); + intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); + std::vector intermediates = {intermediate}; + + // First dispatch + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, intermediate, 1); + compute_encoder->setBytes(&in_size, sizeof(size_t), 2); + compute_encoder->dispatchThreads(grid_dims, group_dims); + + // Second pass to reduce intermediate reduction results written to DRAM + set_array_buffer(compute_encoder, intermediate, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2); + + mod_in_size = (intermediate_size + n_reads - 1) / n_reads; + + thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + thread_group_size = + mod_in_size > thread_group_size ? thread_group_size : mod_in_size; + thread_group_size = + ((thread_group_size + simd_size - 1) / simd_size) * simd_size; + + // If the number of thread groups needed exceeds 1024, we reuse threads + // groups + nthreads = thread_group_size; + group_dims = MTL::Size(thread_group_size, 1, 1); + grid_dims = MTL::Size(nthreads, 1, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); + + d.get_command_buffer(s.index)->addCompletedHandler( + [intermediates](MTL::CommandBuffer*) mutable { + intermediates.clear(); + }); + } } void row_reduce_general_dispatch( @@ -76,22 +126,31 @@ void row_reduce_general_dispatch( const ReductionPlan& plan, const std::vector& axes, MTL::ComputeCommandEncoder* compute_encoder, - metal::Device& d) { - auto kernel = - d.get_kernel("row_reduce_general_" + op_name + type_to_name(in)); + metal::Device& d, + const Stream& s) { + Dtype out_dtype = out.dtype(); + bool is_out_64b_int = is_64b_int(out_dtype); + auto kernel = (is_out_64b_int) + ? 
d.get_kernel( + "row_reduce_general_no_atomics_" + op_name + type_to_name(in)) + : d.get_kernel("row_reduce_general_" + op_name + type_to_name(in)); + + compute_encoder->setComputePipelineState(kernel); // Prepare the arguments for the kernel int n_reads = REDUCE_N_READS; size_t reduction_size = plan.shape.back(); - size_t out_size = out.size(); auto shape = plan.shape; auto strides = plan.strides; + shape.pop_back(); strides.pop_back(); + size_t non_row_reductions = 1; for (auto s : shape) { non_row_reductions *= static_cast(s); } + size_t out_size = out.size(); auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes); for (auto s : rem_shape) { shape.push_back(s); @@ -101,16 +160,6 @@ void row_reduce_general_dispatch( } int ndim = shape.size(); - // Set the arguments for the kernel - compute_encoder->setComputePipelineState(kernel); - set_array_buffer(compute_encoder, in, 0); - set_array_buffer(compute_encoder, out, 1); - compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); - compute_encoder->setBytes(&out_size, sizeof(size_t), 3); - compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); - compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5); - compute_encoder->setBytes(&ndim, sizeof(int), 6); - // Each thread group is responsible for 1 output NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); thread_group_size = @@ -127,7 +176,88 @@ void row_reduce_general_dispatch( MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - compute_encoder->dispatchThreads(grid_dims, group_dims); + if (is_out_64b_int == false || non_row_reductions == 1) { + // Set the arguments for the kernel + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); + compute_encoder->setBytes(&out_size, sizeof(size_t), 3); + compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); + compute_encoder->setBytes( + strides.data(), strides.size() * sizeof(size_t), 5); + compute_encoder->setBytes(&ndim, sizeof(int), 6); + compute_encoder->dispatchThreads(grid_dims, group_dims); + + } else { + // Allocate intermediate array to store partial reduction results + array intermediate = array( + {static_cast(out.size()), static_cast(non_row_reductions)}, + out_dtype, + nullptr, + {}); + intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); + std::vector intermediates = {intermediate}; + + // Set the arguments for the kernel + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, intermediate, 1); + compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); + compute_encoder->setBytes(&out_size, sizeof(size_t), 3); + compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); + compute_encoder->setBytes( + strides.data(), strides.size() * sizeof(size_t), 5); + compute_encoder->setBytes(&ndim, sizeof(int), 6); + compute_encoder->dispatchThreads(grid_dims, group_dims); + + // Set up second dispatch + reduction_size = non_row_reductions; + out_size = 1; + + // Shape of axes that aren't participating in reduction remains unchanged. + std::vector new_shape = rem_shape; + + // Update their strides since they'll be different post partial reduction in + // first compute dispatch. 
+ std::vector new_strides = rem_strides; + new_strides.back() = reduction_size; + for (int i = new_shape.size() - 2; i >= 0; i--) { + new_strides[i] = new_shape[i + 1] * new_strides[i + 1]; + } + ndim = new_shape.size(); + + // Set the arguments for the kernel + set_array_buffer(compute_encoder, intermediate, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); + compute_encoder->setBytes(&out_size, sizeof(size_t), 3); + compute_encoder->setBytes( + new_shape.data(), new_shape.size() * sizeof(int), 4); + compute_encoder->setBytes( + new_strides.data(), new_strides.size() * sizeof(size_t), 5); + compute_encoder->setBytes(&ndim, sizeof(int), 6); + + // Each thread group is responsible for 1 output + thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + thread_group_size = + std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size); + + // Align thread group size with simd_size + thread_group_size = + (thread_group_size + simd_size - 1) / simd_size * simd_size; + assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup()); + + // Launch enough thread groups for each output + n_threads = thread_group_size; + grid_dims = MTL::Size(n_threads, out.size(), 1); + group_dims = MTL::Size(thread_group_size, 1, 1); + + compute_encoder->dispatchThreads(grid_dims, group_dims); + + d.get_command_buffer(s.index)->addCompletedHandler( + [intermediates](MTL::CommandBuffer*) mutable { + intermediates.clear(); + }); + } } void strided_reduce_general_dispatch( @@ -137,9 +267,16 @@ void strided_reduce_general_dispatch( const ReductionPlan& plan, const std::vector& axes, MTL::ComputeCommandEncoder* compute_encoder, - metal::Device& d) { - auto kernel = - d.get_kernel("col_reduce_general_" + op_name + type_to_name(in)); + metal::Device& d, + const Stream& s) { + Dtype out_dtype = out.dtype(); + bool is_out_64b_int = is_64b_int(out_dtype); + auto kernel = (is_out_64b_int) + ? d.get_kernel( + "col_reduce_general_no_atomics_" + op_name + type_to_name(in)) + : d.get_kernel("col_reduce_general_" + op_name + type_to_name(in)); + + compute_encoder->setComputePipelineState(kernel); // Prepare the arguments for the kernel size_t reduction_size = plan.shape.back(); @@ -162,19 +299,7 @@ void strided_reduce_general_dispatch( } int ndim = shape.size(); - // Set the arguments for the kernel - compute_encoder->setComputePipelineState(kernel); - set_array_buffer(compute_encoder, in, 0); - set_array_buffer(compute_encoder, out, 1); - compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); - compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); - compute_encoder->setBytes(&out_size, sizeof(size_t), 4); - compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); - compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6); - compute_encoder->setBytes(&ndim, sizeof(int), 7); - // Select block dimensions - // Each thread reads 16 inputs to give it more work uint n_inputs_per_thread = REDUCE_N_READS; uint n_threads_per_output = @@ -183,14 +308,22 @@ void strided_reduce_general_dispatch( // We spread outputs over the x dimension and inputs over the y dimension // Threads with the same lid.x in a given threadgroup work on the same // output and each thread in the y dimension accumulates for that output + + // Threads with same lid.x, i.e. 
each column of threads work on same output uint threadgroup_dim_x = std::min(out_size, 128ul); + + // Number of threads along y, is dependent on number of reductions needed. uint threadgroup_dim_y = kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x; threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y); + // Derive number of thread groups along x, based on how many threads we need + // along x uint n_threadgroups_x = (out_size + threadgroup_dim_x - 1) / threadgroup_dim_x; + // Derive number of thread groups along y based on how many threads we need + // along y uint n_threadgroups_y = (n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y; @@ -199,18 +332,122 @@ void strided_reduce_general_dispatch( MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions); MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1); - // We set shared memory to be exploited here for reductions within a - // threadgroup - each thread must be able to update its accumulated output - // Note: Each threadgroup should have 32kB of data in threadgroup memory - // and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design - // This should be fine for floats, but we might need to revisit - // if we ever come to doubles. In that case, we should also cut - // down the number of threads we launch in a threadgroup - compute_encoder->setThreadgroupMemoryLength( - safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16), - 0); + if (is_out_64b_int == false) { + // Set the arguments for the kernel + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); + compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); + compute_encoder->setBytes(&out_size, sizeof(size_t), 4); + compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); + compute_encoder->setBytes( + strides.data(), strides.size() * sizeof(size_t), 6); + compute_encoder->setBytes(&ndim, sizeof(int), 7); - compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + // We set shared memory to be exploited here for reductions within a + // threadgroup - each thread must be able to update its accumulated output + // Note: Each threadgroup should have 32kB of data in threadgroup memory + // and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design + // This should be fine for floats, but we might need to revisit + // if we ever come to doubles. 
In that case, we should also cut + // down the number of threads we launch in a threadgroup + compute_encoder->setThreadgroupMemoryLength( + safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16), + 0); + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + + } else { + // Allocate intermediate array to store reduction results from all thread + // groups + array intermediate = array( + {static_cast(out.size()), + static_cast(n_threadgroups_y * non_col_reductions)}, + out_dtype, + nullptr, + {}); + intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); + std::vector intermediates = {intermediate}; + + // Set the arguments for the kernel + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, intermediate, 1); + compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); + compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); + compute_encoder->setBytes(&out_size, sizeof(size_t), 4); + compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); + compute_encoder->setBytes( + strides.data(), strides.size() * sizeof(size_t), 6); + compute_encoder->setBytes(&ndim, sizeof(int), 7); + + // We set shared memory to be exploited here for reductions within a + // threadgroup - each thread must be able to update its accumulated output + // Note: Each threadgroup should have 32kB of data in threadgroup memory + // and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design + // This should be fine for floats, but we might need to revisit + // if we ever come to doubles. In that case, we should also cut + // down the number of threads we launch in a threadgroup + compute_encoder->setThreadgroupMemoryLength( + safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16), + 0); + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + + // Perform second pass of reductions + // Reduce results of threadgroups along y, z from first pass, that + // collectively work on each output element. + reduction_size = n_threadgroups_y * non_col_reductions; + out_size = 1; + + // Shape of axes that aren't participating in reduction remains unchanged. + std::vector new_shape = rem_shape; + + // Update their strides since they'll be different after a partial reduction + // post first compute dispatch. 
+ std::vector new_strides = rem_strides; + new_strides.back() = reduction_size; + for (int i = new_shape.size() - 2; i >= 0; i--) { + new_strides[i] = new_shape[i + 1] * new_strides[i + 1]; + } + ndim = new_shape.size(); + + auto row_reduce_kernel = d.get_kernel( + "row_reduce_general_no_atomics_" + op_name + + type_to_name(intermediate)); + compute_encoder->setComputePipelineState(row_reduce_kernel); + set_array_buffer(compute_encoder, intermediate, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); + compute_encoder->setBytes(&out_size, sizeof(size_t), 3); + compute_encoder->setBytes( + new_shape.data(), new_shape.size() * sizeof(int), 4); + compute_encoder->setBytes( + new_strides.data(), new_strides.size() * sizeof(size_t), 5); + compute_encoder->setBytes(&ndim, sizeof(int), 6); + + // Each thread group is responsible for 1 output + size_t n_reads = REDUCE_N_READS; + size_t thread_group_size = + row_reduce_kernel->maxTotalThreadsPerThreadgroup(); + thread_group_size = + std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size); + + // Align thread group size with simd_size + uint simd_size = row_reduce_kernel->threadExecutionWidth(); + thread_group_size = + (thread_group_size + simd_size - 1) / simd_size * simd_size; + assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup()); + + // Launch enough thread groups for each output + uint n_threads = thread_group_size; + grid_dims = MTL::Size(n_threads, out.size(), 1); + group_dims = MTL::Size(thread_group_size, 1, 1); + + compute_encoder->dispatchThreads(grid_dims, group_dims); + + d.get_command_buffer(s.index)->addCompletedHandler( + [intermediates](MTL::CommandBuffer*) mutable { + intermediates.clear(); + }); + } } } // namespace @@ -223,14 +460,6 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); array in = inputs[0]; - // TODO: Allow specific row and column reductions with types disabled - // due to atomics ? - if (size_of(in.dtype()) == 8) { - std::ostringstream msg; - msg << "[Reduce::eval_gpu] Does not support " << in.dtype(); - throw std::runtime_error(msg.str()); - } - // Make sure no identity reductions trickle down here assert(!axes_.empty()); @@ -297,7 +526,7 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { // Reducing over everything and the data is all there no broadcasting or // slicing etc. 
if (plan.type == ContiguousAllReduce) { - all_reduce_dispatch(in, out, op_name, compute_encoder, d); + all_reduce_dispatch(in, out, op_name, compute_encoder, d, s); } // At least the last dimension is row contiguous and we are reducing over @@ -305,7 +534,7 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { else if ( plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) { row_reduce_general_dispatch( - in, out, op_name, plan, axes_, compute_encoder, d); + in, out, op_name, plan, axes_, compute_encoder, d, s); } // At least the last two dimensions are contiguous and we are doing a @@ -314,7 +543,7 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { plan.type == ContiguousStridedReduce || plan.type == GeneralStridedReduce) { strided_reduce_general_dispatch( - in, out, op_name, plan, axes_, compute_encoder, d); + in, out, op_name, plan, axes_, compute_encoder, d, s); } if (!copies.empty()) { diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py index 30145b83f..70c8a1172 100644 --- a/python/tests/test_reduce.py +++ b/python/tests/test_reduce.py @@ -55,6 +55,8 @@ class TestReduce(mlx_tests.MLXTestCase): "uint8", "uint16", "uint32", + "int64", + "uint64", ] float_dtypes = ["float32"]
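
The test change above exercises the new path: with this patch, 64-bit integer reductions run on the GPU instead of raising. A minimal usage sketch against the public mlx.core Python API (dtype names and expected values shown for illustration):

    import mlx.core as mx

    # sum/prod/min/max over int64/uint64 now dispatch to the *_no_atomics
    # Metal kernels rather than erroring out on the GPU stream.
    a = mx.arange(1024, dtype=mx.int64)
    print(mx.sum(a, stream=mx.gpu))  # expected: 523776
    print(mx.max(a, stream=mx.gpu))  # expected: 1023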
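
For reference, the *_no_atomics kernels replace the single atomic accumulation (not available for 64-bit integer types) with a two-pass scheme: the first dispatch writes one partial result per threadgroup into an intermediate buffer, and a second dispatch reduces those partials into the output. Conceptually, in a NumPy sketch (4096 is an arbitrary stand-in for the per-threadgroup chunk size):

    import numpy as np

    x = np.arange(1 << 20, dtype=np.int64)
    partials = x.reshape(-1, 4096).sum(axis=1)  # pass 1: one partial per "threadgroup"
    total = partials.sum()                      # pass 2: reduce the intermediates
    assert total == x.sum()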