	Add GPU support for uint64/int64 reductions (#569)
		| @@ -63,18 +63,6 @@ struct ArgMax { | ||||
|   } | ||||
| }; | ||||
|  | ||||
| bool simd_shuffle_down(bool data, uint16_t delta) { | ||||
|   return simd_shuffle_down(static_cast<uint32_t>(data), delta); | ||||
| } | ||||
|  | ||||
| uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) { | ||||
|   return as_type<uint64_t>(simd_shuffle_down(as_type<uint2>(data), delta)); | ||||
| } | ||||
|  | ||||
| int64_t simd_shuffle_down(int64_t data, uint16_t delta) { | ||||
|   return as_type<int64_t>(simd_shuffle_down(as_type<uint2>(data), delta)); | ||||
| } | ||||
|  | ||||
| template <typename U> | ||||
| IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) { | ||||
|   return IndexValPair<U>( | ||||
|   | ||||
| @@ -24,11 +24,59 @@ template <typename T, typename Op> | ||||
|       device otype *out [[buffer(1)]], \ | ||||
|       uint tid [[thread_position_in_grid]]); | ||||
|  | ||||
|  | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
| // All reduce | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | ||||
| inline U per_thread_all_reduce( | ||||
|     const device T *in, | ||||
|     const device size_t& in_size, | ||||
|     uint gid, | ||||
|     uint grid_size) { | ||||
|   Op op; | ||||
|   U total_val = Op::init; | ||||
|  | ||||
|   if (gid * N_READS < in_size) { | ||||
|     in += gid * N_READS; | ||||
|  | ||||
|     int r = 0; | ||||
|     for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) { | ||||
|       U vals[N_READS] = {op.init}; | ||||
|  | ||||
|       for(int i = 0; i < N_READS; i++) { | ||||
|         vals[i] = static_cast<U>(in[i]); | ||||
|       } | ||||
|       for(int i = 0; i < N_READS; i++) { | ||||
|         total_val = op(vals[i], total_val); | ||||
|       } | ||||
|  | ||||
|       in += grid_size * N_READS; | ||||
|     } | ||||
|  | ||||
|     // Separate case for the last set as we close the reduction size | ||||
|     size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS; | ||||
|     if (curr_idx < in_size) { | ||||
|       int max_reads = in_size - curr_idx; | ||||
|       T vals[N_READS]; | ||||
|  | ||||
|       for(int i = 0, idx = 0; i < N_READS; i++, idx++) { | ||||
|         idx = idx < max_reads ? idx : max_reads - 1; | ||||
|         vals[i] = in[idx]; | ||||
|       } | ||||
|       for(int i = 0; i < N_READS; i++) { | ||||
|         U val = i < max_reads ? vals[i] : Op::init; | ||||
|         total_val = op(static_cast<U>(val), total_val); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   return total_val; | ||||
| } | ||||
|  | ||||
| // NB: This kernel assumes threads_per_threadgroup is at most | ||||
| // 1024. This way with a simd_size of 32, we are guaranteed to | ||||
| // complete the reduction in two steps of simd-level reductions. | ||||
| template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | ||||
| [[kernel]] void all_reduce( | ||||
|     const device T *in [[buffer(0)]], | ||||
| @@ -40,53 +88,18 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | ||||
|     uint simd_per_group [[simdgroups_per_threadgroup]], | ||||
|     uint simd_lane_id [[thread_index_in_simdgroup]], | ||||
|     uint simd_group_id [[simdgroup_index_in_threadgroup]]) { | ||||
|   // NB: this kernel assumes threads_per_threadgroup is at most | ||||
|   // 1024. This way with a simd_size of 32, we are guaranteed to | ||||
|   // complete the reduction in two steps of simd-level reductions. | ||||
|  | ||||
|   Op op; | ||||
|   threadgroup U local_vals[simd_size]; | ||||
|        | ||||
|   U total_val = Op::init; | ||||
|  | ||||
|   in += gid * N_READS; | ||||
|        | ||||
|   int r = 0; | ||||
|   for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) { | ||||
|     U vals[N_READS] = {op.init};  | ||||
|  | ||||
|     for(int i = 0; i < N_READS; i++) { | ||||
|       vals[i] = static_cast<U>(in[i]); | ||||
|     } | ||||
|     for(int i = 0; i < N_READS; i++) { | ||||
|       total_val = op(vals[i], total_val); | ||||
|     } | ||||
|  | ||||
|     in += grid_size * N_READS; | ||||
|   } | ||||
|  | ||||
|   // Separate case for the last set as we close the reduction size  | ||||
|   size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS; | ||||
|   if (curr_idx < in_size) { | ||||
|     int max_reads = in_size - curr_idx; | ||||
|     T vals[N_READS]; | ||||
|  | ||||
|     for(int i = 0, idx = 0; i < N_READS; i++, idx++) { | ||||
|       idx = idx < max_reads ? idx : max_reads - 1; | ||||
|       vals[i] = in[idx]; | ||||
|     } | ||||
|     for(int i = 0; i < N_READS; i++) { | ||||
|       U val = i < max_reads ? vals[i] : Op::init;  | ||||
|       total_val = op(static_cast<U>(val), total_val); | ||||
|     } | ||||
|   } | ||||
|   U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size); | ||||
|  | ||||
|   // Reduction within simd group | ||||
|   total_val = op.simd_reduce(total_val); | ||||
|   if (simd_lane_id == 0) { | ||||
|     local_vals[simd_group_id] = total_val; | ||||
|   } | ||||
|    | ||||
|  | ||||
|   // Reduction within thread group | ||||
|   threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|   total_val = lid < simd_per_group ? local_vals[lid] : op.init; | ||||
| @@ -98,6 +111,46 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | ||||
| [[kernel]] void all_reduce_no_atomics( | ||||
|     const device T *in [[buffer(0)]], | ||||
|     device U *out [[buffer(1)]], | ||||
|     const device size_t& in_size [[buffer(2)]], | ||||
|     uint gid [[thread_position_in_grid]], | ||||
|     uint lid [[thread_position_in_threadgroup]], | ||||
|     uint grid_size [[threads_per_grid]], | ||||
|     uint simd_per_group [[simdgroups_per_threadgroup]], | ||||
|     uint simd_lane_id [[thread_index_in_simdgroup]], | ||||
|     uint simd_group_id [[simdgroup_index_in_threadgroup]], | ||||
|     uint thread_group_id [[threadgroup_position_in_grid]]) { | ||||
|  | ||||
|   Op op; | ||||
|   threadgroup U local_vals[simd_size]; | ||||
|  | ||||
|   U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size); | ||||
|  | ||||
|   // Reduction within simd group (simd_add isn't supported for uint64/int64 types) | ||||
|   for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) { | ||||
|     total_val = op(total_val, simd_shuffle_down(total_val, lane_offset)); | ||||
|   } | ||||
|   // Write simd group reduction results to local memory | ||||
|   if (simd_lane_id == 0) { | ||||
|     local_vals[simd_group_id] = total_val; | ||||
|   } | ||||
|   threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|   // Reduction of simdgroup reduction results within threadgroup. | ||||
|   total_val = lid < simd_per_group ? local_vals[lid] : op.init; | ||||
|   for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) { | ||||
|     total_val = op(total_val, simd_shuffle_down(total_val, lane_offset)); | ||||
|   } | ||||
|  | ||||
|   // Reduction across threadgroups | ||||
|   if (lid == 0) { | ||||
|     out[thread_group_id] = total_val; | ||||
|   } | ||||
| } | ||||
|  | ||||
| #define instantiate_all_reduce(name, itype, otype, op) \ | ||||
|   template [[host_name("all_reduce_" #name)]] \ | ||||
|   [[kernel]] void all_reduce<itype, otype, op>( \ | ||||
| @@ -111,11 +164,80 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | ||||
|       uint simd_lane_id [[thread_index_in_simdgroup]], \ | ||||
|       uint simd_group_id [[simdgroup_index_in_threadgroup]]); | ||||
|  | ||||
| #define instantiate_all_reduce_no_atomics(name, itype, otype, op) \ | ||||
|   template [[host_name("all_reduce_no_atomics_" #name)]] \ | ||||
|   [[kernel]] void all_reduce_no_atomics<itype, otype, op>( \ | ||||
|       const device itype *in [[buffer(0)]], \ | ||||
|       device otype *out [[buffer(1)]], \ | ||||
|       const device size_t& in_size [[buffer(2)]], \ | ||||
|       uint gid [[thread_position_in_grid]], \ | ||||
|       uint lid [[thread_position_in_threadgroup]], \ | ||||
|       uint grid_size [[threads_per_grid]], \ | ||||
|       uint simd_per_group [[simdgroups_per_threadgroup]], \ | ||||
|       uint simd_lane_id [[thread_index_in_simdgroup]], \ | ||||
|       uint simd_group_id [[simdgroup_index_in_threadgroup]], \ | ||||
|       uint thread_group_id [[threadgroup_position_in_grid]]); | ||||
|  | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
| // Row atomics | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | ||||
| inline U per_thread_row_reduce( | ||||
|     const device T *in, | ||||
|     const constant size_t& reduction_size, | ||||
|     const constant size_t& out_size, | ||||
|     const constant int* shape, | ||||
|     const constant size_t* strides, | ||||
|     const constant int& ndim, | ||||
|     uint lsize_x, | ||||
|     uint lid_x, | ||||
|     uint2 tid) { | ||||
|  | ||||
|   Op op; | ||||
|  | ||||
|   // Each threadgroup handles 1 reduction | ||||
|   // TODO: Specializing elem_to_loc would be slightly faster | ||||
|   int idx = tid.y * out_size + tid.x; | ||||
|   int extra_offset = elem_to_loc(idx, shape, strides, ndim); | ||||
|   in += extra_offset + lid_x * N_READS; | ||||
|    | ||||
|   // The reduction is accumulated here | ||||
|   U total_val = Op::init; | ||||
|  | ||||
|   // Loop over the reduction size within thread group | ||||
|   int r = 0; | ||||
|   for (; r < (int)ceildiv(reduction_size, N_READS*lsize_x) - 1; r++) { | ||||
|     T vals[N_READS];  | ||||
|     for(int i = 0; i < N_READS; i++) { | ||||
|       vals[i] = in[i]; | ||||
|     } | ||||
|     for(int i = 0; i < N_READS; i++) { | ||||
|       total_val = op(static_cast<U>(vals[i]), total_val); | ||||
|     } | ||||
|  | ||||
|     in += lsize_x * N_READS; | ||||
|   } | ||||
|  | ||||
|   // Separate case for the last set as we close the reduction size    | ||||
|   size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS; | ||||
|   if(reduction_index < reduction_size) { | ||||
|     int max_reads = reduction_size - reduction_index; | ||||
|  | ||||
|     T vals[N_READS];  | ||||
|     for(int i = 0; i < N_READS; i++) { | ||||
|       int idx = min(i, max_reads - 1); | ||||
|       vals[i] = static_cast<U>(in[idx]); | ||||
|     } | ||||
|     for(int i = 0; i < N_READS; i++) { | ||||
|       T val = i < max_reads ? vals[i] : Op::init; | ||||
|       total_val = op(static_cast<U>(val), total_val); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   return total_val; | ||||
| } | ||||
|  | ||||
| template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | ||||
| [[kernel]] void row_reduce_general( | ||||
|     const device T *in [[buffer(0)]], | ||||
| @@ -133,46 +255,9 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | ||||
|     uint simd_group_id [[simdgroup_index_in_threadgroup]]) { | ||||
|  | ||||
|   Op op; | ||||
|  | ||||
|   // Each threadgroup handles 1 reduction | ||||
|   // TODO: Specializing elem_to_loc would be slightly faster | ||||
|   int idx = tid.y * out_size + tid.x; | ||||
|   int extra_offset = elem_to_loc(idx, shape, strides, ndim); | ||||
|   in += extra_offset + lid.x * N_READS; | ||||
|    | ||||
|   // The reduction is accumulated here | ||||
|   U total_val = Op::init; | ||||
|   threadgroup U local_vals[simd_size]; | ||||
|  | ||||
|   // Loop over the reduction size within thread group | ||||
|   int r = 0; | ||||
|   for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) { | ||||
|     T vals[N_READS];  | ||||
|     for(int i = 0; i < N_READS; i++) { | ||||
|       vals[i] = in[i]; | ||||
|     } | ||||
|     for(int i = 0; i < N_READS; i++) { | ||||
|       total_val = op(static_cast<U>(vals[i]), total_val); | ||||
|     } | ||||
|  | ||||
|     in += lsize.x * N_READS; | ||||
|   } | ||||
|  | ||||
|   // Separate case for the last set as we close the reduction size    | ||||
|   size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS; | ||||
|   if(reduction_index < reduction_size) { | ||||
|     int max_reads = reduction_size - reduction_index; | ||||
|  | ||||
|     T vals[N_READS];  | ||||
|     for(int i = 0; i < N_READS; i++) { | ||||
|       int idx = min(i, max_reads - 1); | ||||
|       vals[i] = static_cast<U>(in[idx]); | ||||
|     } | ||||
|     for(int i = 0; i < N_READS; i++) { | ||||
|       T val = i < max_reads ? vals[i] : Op::init; | ||||
|       total_val = op(static_cast<U>(val), total_val); | ||||
|     } | ||||
|   } | ||||
|   U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy); | ||||
|  | ||||
|   total_val = op.simd_reduce(total_val); | ||||
|    | ||||
| @@ -194,6 +279,53 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | ||||
| [[kernel]] void row_reduce_general_no_atomics( | ||||
|     const device T *in [[buffer(0)]], | ||||
|     device U *out [[buffer(1)]], | ||||
|     const constant size_t& reduction_size [[buffer(2)]], | ||||
|     const constant size_t& out_size [[buffer(3)]], | ||||
|     const constant int* shape [[buffer(4)]], | ||||
|     const constant size_t* strides [[buffer(5)]], | ||||
|     const constant int& ndim [[buffer(6)]], | ||||
|     uint3 lid [[thread_position_in_threadgroup]], | ||||
|     uint3 lsize [[threads_per_threadgroup]], | ||||
|     uint3 gsize [[threads_per_grid]], | ||||
|     uint3 tid [[threadgroup_position_in_grid]], | ||||
|     uint simd_lane_id [[thread_index_in_simdgroup]], | ||||
|     uint simd_per_group [[simdgroups_per_threadgroup]], | ||||
|     uint simd_group_id [[simdgroup_index_in_threadgroup]]) { | ||||
|  | ||||
|   Op op; | ||||
|  | ||||
|   threadgroup U local_vals[simd_size]; | ||||
|   U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy); | ||||
|  | ||||
|   // Reduction within simd group (simd_add isn't supported for uint64/int64 types) | ||||
|   for (uint16_t i = simd_size/2; i > 0; i /= 2) { | ||||
|     total_val = op(total_val, simd_shuffle_down(total_val, i)); | ||||
|   } | ||||
|  | ||||
|   // Prepare next level | ||||
|   if (simd_lane_id == 0) { | ||||
|     local_vals[simd_group_id] = total_val; | ||||
|   } | ||||
|   threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|   // Reduction within thread group | ||||
|   // Only needed if thread group has multiple simd groups | ||||
|   if(ceildiv(reduction_size, N_READS) > simd_size) { | ||||
|     total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init; | ||||
|     for (uint16_t i = simd_size/2; i > 0; i /= 2) { | ||||
|       total_val = op(total_val, simd_shuffle_down(total_val, i)); | ||||
|     } | ||||
|   } | ||||
|   // The first thread in each threadgroup writes out that threadgroup's row-reduce result | ||||
|   if (lid.x == 0) { | ||||
|     out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val; | ||||
|   } | ||||
| } | ||||
|  | ||||
| #define instantiate_row_reduce_general(name, itype, otype, op) \ | ||||
|   template [[host_name("row_reduce_general_" #name)]] \ | ||||
|   [[kernel]] void row_reduce_general<itype, otype, op>( \ | ||||
| @@ -211,52 +343,59 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | ||||
|       uint simd_per_group [[simdgroups_per_threadgroup]],  \ | ||||
|       uint simd_group_id [[simdgroup_index_in_threadgroup]]); | ||||
|  | ||||
| #define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \ | ||||
|   template [[host_name("row_reduce_general_no_atomics_" #name)]] \ | ||||
|   [[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \ | ||||
|       const device itype *in [[buffer(0)]],  \ | ||||
|       device otype *out [[buffer(1)]],  \ | ||||
|       const constant size_t& reduction_size [[buffer(2)]],  \ | ||||
|       const constant size_t& out_size [[buffer(3)]],  \ | ||||
|       const constant int* shape [[buffer(4)]],  \ | ||||
|       const constant size_t* strides [[buffer(5)]],  \ | ||||
|       const constant int& ndim [[buffer(6)]],  \ | ||||
|       uint3 lid [[thread_position_in_threadgroup]],  \ | ||||
|       uint3 lsize [[threads_per_threadgroup]],  \ | ||||
|       uint3 gsize [[threads_per_grid]], \ | ||||
|       uint3 tid [[threadgroup_position_in_grid]],  \ | ||||
|       uint simd_lane_id [[thread_index_in_simdgroup]],  \ | ||||
|       uint simd_per_group [[simdgroups_per_threadgroup]],  \ | ||||
|       uint simd_group_id [[simdgroup_index_in_threadgroup]]); | ||||
|  | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
| // Column reduce | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS> | ||||
| inline void _contiguous_strided_reduce( | ||||
|     const device T *in,  | ||||
|     device mlx_atomic<U> *out,  | ||||
|     threadgroup U *local_data,  | ||||
|     uint in_idx,  | ||||
|     uint out_idx,  | ||||
|     uint reduction_size,  | ||||
|     uint reduction_stride,  | ||||
|     uint2 tid,  | ||||
|     uint2 lid,  | ||||
| inline U _contiguous_strided_reduce( | ||||
|     const device T *in, | ||||
|     threadgroup U *local_data, | ||||
|     uint in_idx, | ||||
|     uint reduction_size, | ||||
|     uint reduction_stride, | ||||
|     uint2 tid, | ||||
|     uint2 lid, | ||||
|     uint2 lsize) { | ||||
|  | ||||
|   Op op; | ||||
|   T local_vals[N_READS];  | ||||
|   U total_val = Op::init; | ||||
|  | ||||
|   uint base_offset = (tid.y * lsize.y + lid.y) * N_READS; | ||||
|  | ||||
|   for(uint r = 0; r < N_READS; r++) { | ||||
|     uint offset = base_offset + r;  | ||||
|     offset = offset < reduction_size ? offset : reduction_size - 1;  | ||||
|     local_vals[r] = in[in_idx + offset * reduction_stride]; | ||||
|   } | ||||
|  | ||||
|   U total_val = Op::init;  | ||||
|   for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) { | ||||
|     total_val = op(static_cast<U>(total_val), local_vals[r]);  | ||||
|     uint offset = base_offset + r; | ||||
|     total_val = op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]); | ||||
|   } | ||||
|   local_data[lsize.y * lid.x + lid.y] = total_val;  | ||||
|  | ||||
|   local_data[lsize.y * lid.x + lid.y] = total_val; | ||||
|   threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|   U val = Op::init; | ||||
|   if(lid.y == 0) { | ||||
|     U val = op.init;  | ||||
|  | ||||
|     // Perform reduction across columns in thread group | ||||
|     for(uint i = 0; i < lsize.y; i++) { | ||||
|       val = op(val, local_data[lsize.y * lid.x + i]);  | ||||
|       val = op(val, local_data[lsize.y * lid.x + i]); | ||||
|     } | ||||
|  | ||||
|     op.atomic_update(out, val, out_idx); | ||||
|   } | ||||
|  | ||||
|   return val; | ||||
| } | ||||
|  | ||||
| template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS> | ||||
| @@ -265,13 +404,13 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS> | ||||
|     device mlx_atomic<U> *out [[buffer(1)]], | ||||
|     const constant size_t& reduction_size [[buffer(2)]], | ||||
|     const constant size_t& reduction_stride [[buffer(3)]], | ||||
|     const constant size_t& out_size [[buffer(4)]],  | ||||
|     const constant size_t& out_size [[buffer(4)]], | ||||
|     const constant int* shape [[buffer(5)]], | ||||
|     const constant size_t* strides [[buffer(6)]], | ||||
|     const constant int& ndim [[buffer(7)]], | ||||
|     threadgroup U *local_data [[threadgroup(0)]], | ||||
|     uint3 tid [[threadgroup_position_in_grid]],  | ||||
|     uint3 lid [[thread_position_in_threadgroup]],  | ||||
|     uint3 tid [[threadgroup_position_in_grid]], | ||||
|     uint3 lid [[thread_position_in_threadgroup]], | ||||
|     uint3 lsize [[threads_per_threadgroup]]) { | ||||
|   auto out_idx = tid.x * lsize.x + lid.x; | ||||
|   auto in_idx = elem_to_loc( | ||||
| @@ -281,18 +420,66 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS> | ||||
|     ndim | ||||
|   ); | ||||
|  | ||||
|   Op op; | ||||
|   if(out_idx < out_size) { | ||||
|     _contiguous_strided_reduce<T, U, Op, N_READS>( | ||||
|       in,  | ||||
|       out,  | ||||
|       local_data,  | ||||
|       in_idx, | ||||
|       out_idx,  | ||||
|       reduction_size,  | ||||
|       reduction_stride,  | ||||
|       tid.xy,  | ||||
|       lid.xy,  | ||||
|       lsize.xy);  | ||||
|     U val = _contiguous_strided_reduce<T, U, Op, N_READS>( | ||||
|               in, | ||||
|               local_data, | ||||
|               in_idx, | ||||
|               reduction_size, | ||||
|               reduction_stride, | ||||
|               tid.xy, | ||||
|               lid.xy, | ||||
|               lsize.xy); | ||||
|  | ||||
|     // Write out reduction results generated by threadgroups working on specific output element, contiguously. | ||||
|     if (lid.y == 0) { | ||||
|       op.atomic_update(out, val, out_idx); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS> | ||||
| [[kernel]] void col_reduce_general_no_atomics( | ||||
|     const device T *in [[buffer(0)]], | ||||
|     device U *out [[buffer(1)]], | ||||
|     const constant size_t& reduction_size [[buffer(2)]], | ||||
|     const constant size_t& reduction_stride [[buffer(3)]], | ||||
|     const constant size_t& out_size [[buffer(4)]], | ||||
|     const constant int* shape [[buffer(5)]], | ||||
|     const constant size_t* strides [[buffer(6)]], | ||||
|     const constant int& ndim [[buffer(7)]], | ||||
|     threadgroup U *local_data [[threadgroup(0)]], | ||||
|     uint3 tid [[threadgroup_position_in_grid]], | ||||
|     uint3 lid [[thread_position_in_threadgroup]], | ||||
|     uint3 gid [[thread_position_in_grid]], | ||||
|     uint3 lsize [[threads_per_threadgroup]], | ||||
|     uint3 gsize [[threads_per_grid]]) { | ||||
|   auto out_idx = tid.x * lsize.x + lid.x; | ||||
|   auto in_idx = elem_to_loc( | ||||
|     out_idx + tid.z * out_size, | ||||
|     shape, | ||||
|     strides, | ||||
|     ndim | ||||
|   ); | ||||
|  | ||||
|   if(out_idx < out_size) { | ||||
|     U val = _contiguous_strided_reduce<T, U, Op, N_READS>( | ||||
|               in, | ||||
|               local_data, | ||||
|               in_idx, | ||||
|               reduction_size, | ||||
|               reduction_stride, | ||||
|               tid.xy, | ||||
|               lid.xy, | ||||
|               lsize.xy); | ||||
|  | ||||
|     // Write out reduction results generated by threadgroups working on specific output element, contiguously. | ||||
|     if (lid.y == 0) { | ||||
|       uint tgsize_y = ceildiv(gsize.y, lsize.y); | ||||
|       uint tgsize_z = ceildiv(gsize.z, lsize.z); | ||||
|       out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val; | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| @@ -312,6 +499,23 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS> | ||||
|       uint3 lid [[thread_position_in_threadgroup]], \ | ||||
|       uint3 lsize [[threads_per_threadgroup]]); | ||||
|  | ||||
| #define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \ | ||||
|   template [[host_name("col_reduce_general_no_atomics_" #name)]] \ | ||||
|   [[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \ | ||||
|       const device itype *in [[buffer(0)]], \ | ||||
|       device otype *out [[buffer(1)]], \ | ||||
|       const constant size_t& reduction_size [[buffer(2)]], \ | ||||
|       const constant size_t& reduction_stride [[buffer(3)]], \ | ||||
|       const constant size_t& out_size [[buffer(4)]], \ | ||||
|       const constant int* shape [[buffer(5)]],  \ | ||||
|       const constant size_t* strides [[buffer(6)]],  \ | ||||
|       const constant int& ndim [[buffer(7)]],  \ | ||||
|       threadgroup otype *local_data [[threadgroup(0)]], \ | ||||
|       uint3 tid [[threadgroup_position_in_grid]], \ | ||||
|       uint3 lid [[thread_position_in_threadgroup]], \ | ||||
|       uint3 gid [[thread_position_in_grid]], \ | ||||
|       uint3 lsize [[threads_per_threadgroup]], \ | ||||
|       uint3 gsize [[threads_per_grid]]); | ||||
|  | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
| // Instantiations | ||||
| @@ -322,6 +526,15 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS> | ||||
|   instantiate_row_reduce_general(name, itype, otype, op) \ | ||||
|   instantiate_col_reduce_general(name, itype, otype, op) | ||||
|  | ||||
| #define instantiate_reduce_no_atomics(name, itype, otype, op) \ | ||||
|   instantiate_all_reduce_no_atomics(name, itype, otype, op) \ | ||||
|   instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \ | ||||
|   instantiate_col_reduce_general_no_atomics(name, itype, otype, op) | ||||
|  | ||||
| #define instantiate_same_reduce_no_atomics(name, tname, type, op) \ | ||||
|   instantiate_init_reduce(name ##tname, type, op<type>) \ | ||||
|   instantiate_reduce_no_atomics(name ##tname, type, type, op<type>) | ||||
|  | ||||
| #define instantiate_same_reduce(name, tname, type, op) \ | ||||
|   instantiate_init_reduce(name ##tname, type, op<type>) \ | ||||
|   instantiate_reduce(name ##tname, type, type, op<type>) | ||||
| @@ -353,6 +566,9 @@ instantiate_same_reduce(sum, int32, int32_t, Sum) | ||||
| instantiate_same_reduce(sum, float16, half, Sum) | ||||
| instantiate_same_reduce(sum, float32, float, Sum) | ||||
|  | ||||
| instantiate_same_reduce_no_atomics(sum, int64, int64_t, Sum) | ||||
| instantiate_same_reduce_no_atomics(sum, uint64, uint64_t, Sum) | ||||
|  | ||||
| instantiate_same_reduce(prod, uint8, uint8_t, Prod) | ||||
| instantiate_same_reduce(prod, uint16, uint16_t, Prod) | ||||
| instantiate_same_reduce(prod, uint32, uint32_t, Prod) | ||||
| @@ -362,6 +578,9 @@ instantiate_same_reduce(prod, int32, int32_t, Prod) | ||||
| instantiate_same_reduce(prod, float16, half, Prod) | ||||
| instantiate_same_reduce(prod, float32, float, Prod) | ||||
|  | ||||
| instantiate_same_reduce_no_atomics(prod, int64, int64_t, Prod) | ||||
| instantiate_same_reduce_no_atomics(prod, uint64, uint64_t, Prod) | ||||
|  | ||||
| instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum) | ||||
| instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod) | ||||
|  | ||||
| @@ -381,6 +600,9 @@ instantiate_same_reduce(min_, int32, int32_t, Min) | ||||
| instantiate_same_reduce(min_, float16, half, Min) | ||||
| instantiate_same_reduce(min_, float32, float, Min) | ||||
|  | ||||
| instantiate_same_reduce_no_atomics(min_, int64, int64_t, Min) | ||||
| instantiate_same_reduce_no_atomics(min_, uint64, uint64_t, Min) | ||||
|  | ||||
| instantiate_same_reduce(max_, uint8, uint8_t, Max) | ||||
| instantiate_same_reduce(max_, uint16, uint16_t, Max) | ||||
| instantiate_same_reduce(max_, uint32, uint32_t, Max) | ||||
| @@ -390,5 +612,8 @@ instantiate_same_reduce(max_, int32, int32_t, Max) | ||||
| instantiate_same_reduce(max_, float16, half, Max) | ||||
| instantiate_same_reduce(max_, float32, float, Max) | ||||
|  | ||||
| instantiate_same_reduce_no_atomics(max_, int64, int64_t, Max) | ||||
| instantiate_same_reduce_no_atomics(max_, uint64, uint64_t, Max) | ||||
|  | ||||
| instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min) | ||||
| instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max) | ||||
|   | ||||
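The instantiation macros above determine the host-visible kernel names that the dispatch code in reduce.cpp looks up. Below is a small, illustrative C++ sketch (not part of the commit; the op and type names are taken from the instantiations above, and "int64" is assumed to be what type_to_name(in) returns for int64 inputs) of how those names are assembled:

    // Illustrative only: rebuild the "no atomics" kernel names the same way the
    // host does, i.e. "all_reduce_no_atomics_" + op_name + type_to_name(in).
    #include <iostream>
    #include <string>

    int main() {
      const std::string op_name = "sum";     // reduction op name used in reduce.cpp
      const std::string type_name = "int64"; // assumed result of type_to_name(in)
      std::cout << "all_reduce_no_atomics_" + op_name + type_name << "\n";
      std::cout << "row_reduce_general_no_atomics_" + op_name + type_name << "\n";
      std::cout << "col_reduce_general_no_atomics_" + op_name + type_name << "\n";
      return 0;
    }

These strings correspond to the host_name attributes produced by instantiate_same_reduce_no_atomics(sum, int64, int64_t, Sum) further up.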
| @@ -256,3 +256,21 @@ inline bfloat16_t log1p(bfloat16_t x) { | ||||
|  | ||||
|   return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); | ||||
| } | ||||
|  | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
| // SIMD shuffle ops | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) { | ||||
|   return as_type<uint64_t>( | ||||
|       metal::simd_shuffle_down(as_type<uint2>(data), delta)); | ||||
| } | ||||
|  | ||||
| inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) { | ||||
|   return as_type<int64_t>( | ||||
|       metal::simd_shuffle_down(as_type<uint2>(data), delta)); | ||||
| } | ||||
|  | ||||
| inline bool simd_shuffle_down(bool data, uint16_t delta) { | ||||
|   return simd_shuffle_down(static_cast<uint32_t>(data), delta); | ||||
| } | ||||
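metal::simd_shuffle_down has no 64-bit integer variant, so the helpers above reinterpret the value as a uint2, shuffle it, and reinterpret it back. The following Metal-style sketch mirrors the reduction loops used by the new *_no_atomics kernels; it is illustrative only, with Op standing for one of the reduction functors (such as Sum) and simd_size the 32-lane SIMD width used throughout these kernels:

    // Sketch: a full simdgroup reduction built from simd_shuffle_down; with the
    // overloads above this now also works for int64_t and uint64_t values.
    template <typename U, typename Op>
    inline U simd_tree_reduce(U total_val, Op op) {
      for (uint16_t lane_offset = simd_size / 2; lane_offset > 0; lane_offset /= 2) {
        total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
      }
      return total_val;
    }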
| @@ -28,35 +28,40 @@ inline auto safe_divup(size_t n, size_t m) { | ||||
|   return safe_div(n, m) * m; | ||||
| } | ||||
|  | ||||
| inline bool is_64b_int(Dtype dtype) { | ||||
|   return dtype == int64 || dtype == uint64; | ||||
| } | ||||
|  | ||||
| // All Reduce | ||||
| void all_reduce_dispatch( | ||||
|     const array& in, | ||||
|     array& out, | ||||
|     const std::string& op_name, | ||||
|     MTL::ComputeCommandEncoder* compute_encoder, | ||||
|     metal::Device& d) { | ||||
|   // Get kernel and encode buffers | ||||
|   size_t in_size = in.size(); | ||||
|   auto kernel = d.get_kernel("all_reduce_" + op_name + type_to_name(in)); | ||||
|     metal::Device& d, | ||||
|     const Stream& s) { | ||||
|   Dtype out_dtype = out.dtype(); | ||||
|   bool is_out_64b_int = is_64b_int(out_dtype); | ||||
|   auto kernel = (is_out_64b_int) | ||||
|       ? d.get_kernel("all_reduce_no_atomics_" + op_name + type_to_name(in)) | ||||
|       : d.get_kernel("all_reduce_" + op_name + type_to_name(in)); | ||||
|  | ||||
|   compute_encoder->setComputePipelineState(kernel); | ||||
|   set_array_buffer(compute_encoder, in, 0); | ||||
|   set_array_buffer(compute_encoder, out, 1); | ||||
|   compute_encoder->setBytes(&in_size, sizeof(size_t), 2); | ||||
|  | ||||
|   // Set grid dimensions | ||||
|  | ||||
|   // We make sure each thread has enough to do by making it read in | ||||
|   // at least n_reads inputs | ||||
|   int n_reads = REDUCE_N_READS; | ||||
|   size_t in_size = in.size(); | ||||
|  | ||||
|   // mod_in_size gives us the groups of n_reads needed to go over the entire | ||||
|   // input | ||||
|   uint mod_in_size = (in_size + n_reads - 1) / n_reads; | ||||
|  | ||||
|   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); | ||||
|   thread_group_size = | ||||
|       mod_in_size > thread_group_size ? thread_group_size : mod_in_size; | ||||
|   uint simd_size = kernel->threadExecutionWidth(); | ||||
|   thread_group_size = | ||||
|       ((thread_group_size + simd_size - 1) / simd_size) * simd_size; | ||||
|  | ||||
|   // If the number of thread groups needed exceeds 1024, we reuse thread groups | ||||
|   uint n_thread_groups = safe_div(mod_in_size, thread_group_size); | ||||
| @@ -66,7 +71,52 @@ void all_reduce_dispatch( | ||||
|   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); | ||||
|   MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); | ||||
|  | ||||
|   compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||
|   // Encode buffers and dispatch | ||||
|   if (is_out_64b_int == false || n_thread_groups == 1) { | ||||
|     set_array_buffer(compute_encoder, in, 0); | ||||
|     set_array_buffer(compute_encoder, out, 1); | ||||
|     compute_encoder->setBytes(&in_size, sizeof(size_t), 2); | ||||
|     compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||
|  | ||||
|   } else { | ||||
|     // Allocate intermediate array to store partial reduction results | ||||
|     size_t intermediate_size = n_thread_groups; | ||||
|     array intermediate = | ||||
|         array({static_cast<int>(intermediate_size)}, out_dtype, nullptr, {}); | ||||
|     intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); | ||||
|     std::vector<array> intermediates = {intermediate}; | ||||
|  | ||||
|     // First dispatch | ||||
|     set_array_buffer(compute_encoder, in, 0); | ||||
|     set_array_buffer(compute_encoder, intermediate, 1); | ||||
|     compute_encoder->setBytes(&in_size, sizeof(size_t), 2); | ||||
|     compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||
|  | ||||
|     // Second pass to reduce intermediate reduction results written to DRAM | ||||
|     set_array_buffer(compute_encoder, intermediate, 0); | ||||
|     set_array_buffer(compute_encoder, out, 1); | ||||
|     compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2); | ||||
|  | ||||
|     mod_in_size = (intermediate_size + n_reads - 1) / n_reads; | ||||
|  | ||||
|     thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); | ||||
|     thread_group_size = | ||||
|         mod_in_size > thread_group_size ? thread_group_size : mod_in_size; | ||||
|     thread_group_size = | ||||
|         ((thread_group_size + simd_size - 1) / simd_size) * simd_size; | ||||
|  | ||||
|     // If the number of thread groups needed exceeds 1024, we reuse thread | ||||
|     // groups | ||||
|     nthreads = thread_group_size; | ||||
|     group_dims = MTL::Size(thread_group_size, 1, 1); | ||||
|     grid_dims = MTL::Size(nthreads, 1, 1); | ||||
|     compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||
|  | ||||
|     d.get_command_buffer(s.index)->addCompletedHandler( | ||||
|         [intermediates](MTL::CommandBuffer*) mutable { | ||||
|           intermediates.clear(); | ||||
|         }); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void row_reduce_general_dispatch( | ||||
| @@ -76,22 +126,31 @@ void row_reduce_general_dispatch( | ||||
|     const ReductionPlan& plan, | ||||
|     const std::vector<int>& axes, | ||||
|     MTL::ComputeCommandEncoder* compute_encoder, | ||||
|     metal::Device& d) { | ||||
|   auto kernel = | ||||
|       d.get_kernel("row_reduce_general_" + op_name + type_to_name(in)); | ||||
|     metal::Device& d, | ||||
|     const Stream& s) { | ||||
|   Dtype out_dtype = out.dtype(); | ||||
|   bool is_out_64b_int = is_64b_int(out_dtype); | ||||
|   auto kernel = (is_out_64b_int) | ||||
|       ? d.get_kernel( | ||||
|             "row_reduce_general_no_atomics_" + op_name + type_to_name(in)) | ||||
|       : d.get_kernel("row_reduce_general_" + op_name + type_to_name(in)); | ||||
|  | ||||
|   compute_encoder->setComputePipelineState(kernel); | ||||
|  | ||||
|   // Prepare the arguments for the kernel | ||||
|   int n_reads = REDUCE_N_READS; | ||||
|   size_t reduction_size = plan.shape.back(); | ||||
|   size_t out_size = out.size(); | ||||
|   auto shape = plan.shape; | ||||
|   auto strides = plan.strides; | ||||
|  | ||||
|   shape.pop_back(); | ||||
|   strides.pop_back(); | ||||
|  | ||||
|   size_t non_row_reductions = 1; | ||||
|   for (auto s : shape) { | ||||
|     non_row_reductions *= static_cast<size_t>(s); | ||||
|   } | ||||
|   size_t out_size = out.size(); | ||||
|   auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes); | ||||
|   for (auto s : rem_shape) { | ||||
|     shape.push_back(s); | ||||
| @@ -101,16 +160,6 @@ void row_reduce_general_dispatch( | ||||
|   } | ||||
|   int ndim = shape.size(); | ||||
|  | ||||
|   // Set the arguments for the kernel | ||||
|   compute_encoder->setComputePipelineState(kernel); | ||||
|   set_array_buffer(compute_encoder, in, 0); | ||||
|   set_array_buffer(compute_encoder, out, 1); | ||||
|   compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); | ||||
|   compute_encoder->setBytes(&out_size, sizeof(size_t), 3); | ||||
|   compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); | ||||
|   compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5); | ||||
|   compute_encoder->setBytes(&ndim, sizeof(int), 6); | ||||
|  | ||||
|   // Each thread group is responsible for 1 output | ||||
|   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); | ||||
|   thread_group_size = | ||||
| @@ -127,7 +176,88 @@ void row_reduce_general_dispatch( | ||||
|   MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1); | ||||
|   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); | ||||
|  | ||||
|   compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||
|   if (is_out_64b_int == false || non_row_reductions == 1) { | ||||
|     // Set the arguments for the kernel | ||||
|     set_array_buffer(compute_encoder, in, 0); | ||||
|     set_array_buffer(compute_encoder, out, 1); | ||||
|     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); | ||||
|     compute_encoder->setBytes(&out_size, sizeof(size_t), 3); | ||||
|     compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); | ||||
|     compute_encoder->setBytes( | ||||
|         strides.data(), strides.size() * sizeof(size_t), 5); | ||||
|     compute_encoder->setBytes(&ndim, sizeof(int), 6); | ||||
|     compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||
|  | ||||
|   } else { | ||||
|     // Allocate intermediate array to store partial reduction results | ||||
|     array intermediate = array( | ||||
|         {static_cast<int>(out.size()), static_cast<int>(non_row_reductions)}, | ||||
|         out_dtype, | ||||
|         nullptr, | ||||
|         {}); | ||||
|     intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); | ||||
|     std::vector<array> intermediates = {intermediate}; | ||||
|  | ||||
|     // Set the arguments for the kernel | ||||
|     set_array_buffer(compute_encoder, in, 0); | ||||
|     set_array_buffer(compute_encoder, intermediate, 1); | ||||
|     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); | ||||
|     compute_encoder->setBytes(&out_size, sizeof(size_t), 3); | ||||
|     compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); | ||||
|     compute_encoder->setBytes( | ||||
|         strides.data(), strides.size() * sizeof(size_t), 5); | ||||
|     compute_encoder->setBytes(&ndim, sizeof(int), 6); | ||||
|     compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||
|  | ||||
|     // Set up second dispatch | ||||
|     reduction_size = non_row_reductions; | ||||
|     out_size = 1; | ||||
|  | ||||
|     // Shape of axes that aren't participating in reduction remains unchanged. | ||||
|     std::vector<int> new_shape = rem_shape; | ||||
|  | ||||
|     // Update their strides since they'll be different after the partial | ||||
|     // reduction in the first compute dispatch. | ||||
|     std::vector<size_t> new_strides = rem_strides; | ||||
|     new_strides.back() = reduction_size; | ||||
|     for (int i = new_shape.size() - 2; i >= 0; i--) { | ||||
|       new_strides[i] = new_shape[i + 1] * new_strides[i + 1]; | ||||
|     } | ||||
|     ndim = new_shape.size(); | ||||
|  | ||||
|     // Set the arguments for the kernel | ||||
|     set_array_buffer(compute_encoder, intermediate, 0); | ||||
|     set_array_buffer(compute_encoder, out, 1); | ||||
|     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); | ||||
|     compute_encoder->setBytes(&out_size, sizeof(size_t), 3); | ||||
|     compute_encoder->setBytes( | ||||
|         new_shape.data(), new_shape.size() * sizeof(int), 4); | ||||
|     compute_encoder->setBytes( | ||||
|         new_strides.data(), new_strides.size() * sizeof(size_t), 5); | ||||
|     compute_encoder->setBytes(&ndim, sizeof(int), 6); | ||||
|  | ||||
|     // Each thread group is responsible for 1 output | ||||
|     thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); | ||||
|     thread_group_size = | ||||
|         std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size); | ||||
|  | ||||
|     // Align thread group size with simd_size | ||||
|     thread_group_size = | ||||
|         (thread_group_size + simd_size - 1) / simd_size * simd_size; | ||||
|     assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup()); | ||||
|  | ||||
|     // Launch enough thread groups for each output | ||||
|     n_threads = thread_group_size; | ||||
|     grid_dims = MTL::Size(n_threads, out.size(), 1); | ||||
|     group_dims = MTL::Size(thread_group_size, 1, 1); | ||||
|  | ||||
|     compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||
|  | ||||
|     d.get_command_buffer(s.index)->addCompletedHandler( | ||||
|         [intermediates](MTL::CommandBuffer*) mutable { | ||||
|           intermediates.clear(); | ||||
|         }); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void strided_reduce_general_dispatch( | ||||
| @@ -137,9 +267,16 @@ void strided_reduce_general_dispatch( | ||||
|     const ReductionPlan& plan, | ||||
|     const std::vector<int>& axes, | ||||
|     MTL::ComputeCommandEncoder* compute_encoder, | ||||
|     metal::Device& d) { | ||||
|   auto kernel = | ||||
|       d.get_kernel("col_reduce_general_" + op_name + type_to_name(in)); | ||||
|     metal::Device& d, | ||||
|     const Stream& s) { | ||||
|   Dtype out_dtype = out.dtype(); | ||||
|   bool is_out_64b_int = is_64b_int(out_dtype); | ||||
|   auto kernel = (is_out_64b_int) | ||||
|       ? d.get_kernel( | ||||
|             "col_reduce_general_no_atomics_" + op_name + type_to_name(in)) | ||||
|       : d.get_kernel("col_reduce_general_" + op_name + type_to_name(in)); | ||||
|  | ||||
|   compute_encoder->setComputePipelineState(kernel); | ||||
|  | ||||
|   // Prepare the arguments for the kernel | ||||
|   size_t reduction_size = plan.shape.back(); | ||||
| @@ -162,19 +299,7 @@ void strided_reduce_general_dispatch( | ||||
|   } | ||||
|   int ndim = shape.size(); | ||||
|  | ||||
|   // Set the arguments for the kernel | ||||
|   compute_encoder->setComputePipelineState(kernel); | ||||
|   set_array_buffer(compute_encoder, in, 0); | ||||
|   set_array_buffer(compute_encoder, out, 1); | ||||
|   compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); | ||||
|   compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); | ||||
|   compute_encoder->setBytes(&out_size, sizeof(size_t), 4); | ||||
|   compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); | ||||
|   compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6); | ||||
|   compute_encoder->setBytes(&ndim, sizeof(int), 7); | ||||
|  | ||||
|   // Select block dimensions | ||||
|  | ||||
|   // Each thread reads 16 inputs to give it more work | ||||
|   uint n_inputs_per_thread = REDUCE_N_READS; | ||||
|   uint n_threads_per_output = | ||||
| @@ -183,14 +308,22 @@ void strided_reduce_general_dispatch( | ||||
|   // We spread outputs over the x dimension and inputs over the y dimension | ||||
|   // Threads with the same lid.x in a given threadgroup work on the same | ||||
|   // output and each thread in the y dimension accumulates for that output | ||||
|  | ||||
|   // Threads with the same lid.x, i.e. each column of threads, work on the same output | ||||
|   uint threadgroup_dim_x = std::min(out_size, 128ul); | ||||
|  | ||||
|   // The number of threads along y depends on the number of reductions needed. | ||||
|   uint threadgroup_dim_y = | ||||
|       kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x; | ||||
|   threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y); | ||||
|  | ||||
|   // Derive number of thread groups along x, based on how many threads we need | ||||
|   // along x | ||||
|   uint n_threadgroups_x = | ||||
|       (out_size + threadgroup_dim_x - 1) / threadgroup_dim_x; | ||||
|  | ||||
|   // Derive number of thread groups along y based on how many threads we need | ||||
|   // along y | ||||
|   uint n_threadgroups_y = | ||||
|       (n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y; | ||||
|  | ||||
| @@ -199,18 +332,122 @@ void strided_reduce_general_dispatch( | ||||
|       MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions); | ||||
|   MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1); | ||||
|  | ||||
|   // We set shared memory to be exploited here for reductions within a | ||||
|   // threadgroup - each thread must be able to update its accumulated output | ||||
|   // Note: Each threadgroup should have 32kB of data in threadgroup memory | ||||
|   //       and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design | ||||
|   //       This should be fine for floats, but we might need to revisit | ||||
|   //       if we ever come to doubles. In that case, we should also cut | ||||
|   //       down the number of threads we launch in a threadgroup | ||||
|   compute_encoder->setThreadgroupMemoryLength( | ||||
|       safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16), | ||||
|       0); | ||||
|   if (is_out_64b_int == false) { | ||||
|     // Set the arguments for the kernel | ||||
|     set_array_buffer(compute_encoder, in, 0); | ||||
|     set_array_buffer(compute_encoder, out, 1); | ||||
|     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); | ||||
|     compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); | ||||
|     compute_encoder->setBytes(&out_size, sizeof(size_t), 4); | ||||
|     compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); | ||||
|     compute_encoder->setBytes( | ||||
|         strides.data(), strides.size() * sizeof(size_t), 6); | ||||
|     compute_encoder->setBytes(&ndim, sizeof(int), 7); | ||||
|  | ||||
|   compute_encoder->dispatchThreadgroups(grid_dims, group_dims); | ||||
|     // We set shared memory to be exploited here for reductions within a | ||||
|     // threadgroup - each thread must be able to update its accumulated output | ||||
|     // Note: Each threadgroup should have 32kB of data in threadgroup memory | ||||
|     //       and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design | ||||
|     //       This should be fine for floats, but we might need to revisit | ||||
|     //       if we ever come to doubles. In that case, we should also cut | ||||
|     //       down the number of threads we launch in a threadgroup | ||||
|     compute_encoder->setThreadgroupMemoryLength( | ||||
|         safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16), | ||||
|         0); | ||||
|     compute_encoder->dispatchThreadgroups(grid_dims, group_dims); | ||||
|  | ||||
|   } else { | ||||
|     // Allocate intermediate array to store reduction results from all thread | ||||
|     // groups | ||||
|     array intermediate = array( | ||||
|         {static_cast<int>(out.size()), | ||||
|          static_cast<int>(n_threadgroups_y * non_col_reductions)}, | ||||
|         out_dtype, | ||||
|         nullptr, | ||||
|         {}); | ||||
|     intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); | ||||
|     std::vector<array> intermediates = {intermediate}; | ||||
|  | ||||
|     // Set the arguments for the kernel | ||||
|     set_array_buffer(compute_encoder, in, 0); | ||||
|     set_array_buffer(compute_encoder, intermediate, 1); | ||||
|     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); | ||||
|     compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); | ||||
|     compute_encoder->setBytes(&out_size, sizeof(size_t), 4); | ||||
|     compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); | ||||
|     compute_encoder->setBytes( | ||||
|         strides.data(), strides.size() * sizeof(size_t), 6); | ||||
|     compute_encoder->setBytes(&ndim, sizeof(int), 7); | ||||
|  | ||||
|     // We set shared memory to be exploited here for reductions within a | ||||
|     // threadgroup - each thread must be able to update its accumulated output | ||||
|     // Note: Each threadgroup should have 32kB of data in threadgroup memory | ||||
|     //       and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design | ||||
|     //       This should be fine for floats, but we might need to revisit | ||||
|     //       if we ever come to doubles. In that case, we should also cut | ||||
|     //       down the number of threads we launch in a threadgroup | ||||
|     compute_encoder->setThreadgroupMemoryLength( | ||||
|         safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16), | ||||
|         0); | ||||
|     compute_encoder->dispatchThreadgroups(grid_dims, group_dims); | ||||
|  | ||||
|     // Second pass: reduce the per-threadgroup results along y and z from the | ||||
|     // first pass that collectively contribute to each output element. | ||||
|     reduction_size = n_threadgroups_y * non_col_reductions; | ||||
|     out_size = 1; | ||||
|  | ||||
|     // Shape of axes that aren't participating in reduction remains unchanged. | ||||
|     std::vector<int> new_shape = rem_shape; | ||||
|  | ||||
|     // Update their strides since they'll be different after the partial | ||||
|     // reduction in the first compute dispatch. | ||||
|     std::vector<size_t> new_strides = rem_strides; | ||||
|     new_strides.back() = reduction_size; | ||||
|     for (int i = new_shape.size() - 2; i >= 0; i--) { | ||||
|       new_strides[i] = new_shape[i + 1] * new_strides[i + 1]; | ||||
|     } | ||||
|     ndim = new_shape.size(); | ||||
|  | ||||
|     auto row_reduce_kernel = d.get_kernel( | ||||
|         "row_reduce_general_no_atomics_" + op_name + | ||||
|         type_to_name(intermediate)); | ||||
|     compute_encoder->setComputePipelineState(row_reduce_kernel); | ||||
|     set_array_buffer(compute_encoder, intermediate, 0); | ||||
|     set_array_buffer(compute_encoder, out, 1); | ||||
|     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); | ||||
|     compute_encoder->setBytes(&out_size, sizeof(size_t), 3); | ||||
|     compute_encoder->setBytes( | ||||
|         new_shape.data(), new_shape.size() * sizeof(int), 4); | ||||
|     compute_encoder->setBytes( | ||||
|         new_strides.data(), new_strides.size() * sizeof(size_t), 5); | ||||
|     compute_encoder->setBytes(&ndim, sizeof(int), 6); | ||||
|  | ||||
|     // Each thread group is responsible for 1 output | ||||
|     size_t n_reads = REDUCE_N_READS; | ||||
|     size_t thread_group_size = | ||||
|         row_reduce_kernel->maxTotalThreadsPerThreadgroup(); | ||||
|     thread_group_size = | ||||
|         std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size); | ||||
|  | ||||
|     // Align thread group size with simd_size | ||||
|     uint simd_size = row_reduce_kernel->threadExecutionWidth(); | ||||
|     thread_group_size = | ||||
|         (thread_group_size + simd_size - 1) / simd_size * simd_size; | ||||
|     assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup()); | ||||
|  | ||||
|     // Launch enough thread groups for each output | ||||
|     uint n_threads = thread_group_size; | ||||
|     grid_dims = MTL::Size(n_threads, out.size(), 1); | ||||
|     group_dims = MTL::Size(thread_group_size, 1, 1); | ||||
|  | ||||
|     compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||
|  | ||||
|     d.get_command_buffer(s.index)->addCompletedHandler( | ||||
|         [intermediates](MTL::CommandBuffer*) mutable { | ||||
|           intermediates.clear(); | ||||
|         }); | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
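The launch-geometry arithmetic used by all_reduce_dispatch above is easy to reproduce on its own. The standalone C++ sketch below is illustrative only: the example sizes are assumptions, and safe_div is treated as a ceiling division. It mirrors the logic above: each thread covers n_reads inputs, the threadgroup size is clamped to the device limit and rounded up to a simd multiple, and when more than one threadgroup is needed the 64-bit path writes one partial result per threadgroup into an intermediate array and reduces it in a second, no-atomics dispatch.

    #include <cstddef>
    #include <cstdio>

    int main() {
      // Example values; REDUCE_N_READS and the device limits come from MLX/Metal.
      const std::size_t n_reads = 16;      // REDUCE_N_READS (assumed value)
      const std::size_t max_tg = 1024;     // maxTotalThreadsPerThreadgroup()
      const std::size_t simd_size = 32;    // threadExecutionWidth()
      const std::size_t in_size = 1 << 20; // number of input elements

      // Groups of n_reads needed to cover the whole input.
      std::size_t mod_in_size = (in_size + n_reads - 1) / n_reads;

      // Clamp to the device limit, then round up to a multiple of simd_size.
      std::size_t thread_group_size = mod_in_size > max_tg ? max_tg : mod_in_size;
      thread_group_size = ((thread_group_size + simd_size - 1) / simd_size) * simd_size;

      // One partial result per threadgroup; more than one means the 64-bit path
      // needs the intermediate array and a second reduction pass.
      std::size_t n_thread_groups =
          (mod_in_size + thread_group_size - 1) / thread_group_size;
      std::printf("threadgroup size %zu, %zu partial result(s)\n",
                  thread_group_size, n_thread_groups);
      return 0;
    }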
| @@ -223,14 +460,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   assert(inputs.size() == 1); | ||||
|   array in = inputs[0]; | ||||
|  | ||||
|   // TODO: Allow specific row and column reductions with types disabled | ||||
|   // due to atomics ? | ||||
|   if (size_of(in.dtype()) == 8) { | ||||
|     std::ostringstream msg; | ||||
|     msg << "[Reduce::eval_gpu] Does not support " << in.dtype(); | ||||
|     throw std::runtime_error(msg.str()); | ||||
|   } | ||||
|  | ||||
|   // Make sure no identity reductions trickle down here | ||||
|   assert(!axes_.empty()); | ||||
|  | ||||
| @@ -297,7 +526,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|     // Reducing over everything and the data is all there no broadcasting or | ||||
|     // slicing etc. | ||||
|     if (plan.type == ContiguousAllReduce) { | ||||
|       all_reduce_dispatch(in, out, op_name, compute_encoder, d); | ||||
|       all_reduce_dispatch(in, out, op_name, compute_encoder, d, s); | ||||
|     } | ||||
|  | ||||
|     // At least the last dimension is row contiguous and we are reducing over | ||||
| @@ -305,7 +534,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|     else if ( | ||||
|         plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) { | ||||
|       row_reduce_general_dispatch( | ||||
|           in, out, op_name, plan, axes_, compute_encoder, d); | ||||
|           in, out, op_name, plan, axes_, compute_encoder, d, s); | ||||
|     } | ||||
|  | ||||
|     // At least the last two dimensions are contiguous and we are doing a | ||||
| @@ -314,7 +543,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|         plan.type == ContiguousStridedReduce || | ||||
|         plan.type == GeneralStridedReduce) { | ||||
|       strided_reduce_general_dispatch( | ||||
|           in, out, op_name, plan, axes_, compute_encoder, d); | ||||
|           in, out, op_name, plan, axes_, compute_encoder, d, s); | ||||
|     } | ||||
|  | ||||
|     if (!copies.empty()) { | ||||
|   | ||||
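One natural question for the 64-bit column-reduce path is whether the threadgroup scratch still fits once out.itemsize() is 8 bytes. The comment in strided_reduce_general_dispatch above cites a 32 kB threadgroup-memory budget with threadgroup_dim_x * threadgroup_dim_y capped at 1024 by design, so a quick standalone check (plain C++, illustrative only) confirms 64-bit accumulators stay well within it:

    #include <cstddef>

    int main() {
      constexpr std::size_t max_threads_per_tg = 1024; // threadgroup_dim_x * threadgroup_dim_y cap
      constexpr std::size_t itemsize_64b = 8;          // int64_t / uint64_t accumulator
      constexpr std::size_t budget = 32 * 1024;        // threadgroup memory cited in the comment

      // _contiguous_strided_reduce keeps one partial value per thread in local_data.
      constexpr std::size_t needed = max_threads_per_tg * itemsize_64b; // 8 kB
      static_assert(needed <= budget,
                    "64-bit accumulators exceed the threadgroup memory budget");
      return 0;
    }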
| @@ -55,6 +55,8 @@ class TestReduce(mlx_tests.MLXTestCase): | ||||
|             "uint8", | ||||
|             "uint16", | ||||
|             "uint32", | ||||
|             "int64", | ||||
|             "uint64", | ||||
|         ] | ||||
|         float_dtypes = ["float32"] | ||||
|  | ||||
|   | ||||