Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-25 20:58:13 +08:00)
			
		
		
		
	Add GPU support for uint64/int64 reductions (#569)
This commit is contained in:
		| @@ -63,18 +63,6 @@ struct ArgMax { | |||||||
|   } |   } | ||||||
| }; | }; | ||||||
|  |  | ||||||
// bool overload: widens the flag to uint32_t, shuffles it with the uint32_t
// overload of simd_shuffle_down, and converts the result back to bool.
bool simd_shuffle_down(bool data, uint16_t delta) {
  const uint32_t widened = static_cast<uint32_t>(data);
  return simd_shuffle_down(widened, delta);
}
|  |  | ||||||
// uint64_t overload: reinterprets the value as a uint2, shuffles that pair
// with the uint2 overload, and reinterprets the result as uint64_t again.
uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<uint64_t>(simd_shuffle_down(halves, delta));
}
|  |  | ||||||
// int64_t overload: reinterprets the value as a uint2, shuffles that pair
// with the uint2 overload, and reinterprets the result as int64_t again.
int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<int64_t>(simd_shuffle_down(halves, delta));
}
|  |  | ||||||
| template <typename U> | template <typename U> | ||||||
| IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) { | IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) { | ||||||
|   return IndexValPair<U>( |   return IndexValPair<U>( | ||||||
|   | |||||||
| @@ -24,11 +24,59 @@ template <typename T, typename Op> | |||||||
|       device otype *out [[buffer(1)]], \ |       device otype *out [[buffer(1)]], \ | ||||||
|       uint tid [[thread_position_in_grid]]); |       uint tid [[thread_position_in_grid]]); | ||||||
|  |  | ||||||
|  |  | ||||||
| /////////////////////////////////////////////////////////////////////////////// | /////////////////////////////////////////////////////////////////////////////// | ||||||
| // All reduce | // All reduce | ||||||
| /////////////////////////////////////////////////////////////////////////////// | /////////////////////////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
// Accumulates one thread's share of a flat reduction over `in`.
//
// Thread `gid` starts at element gid * N_READS and consumes N_READS
// consecutive elements per round, advancing by grid_size * N_READS elements
// between rounds. Every round except the last is read without bounds checks;
// the last round is checked against `in_size` and short reads are padded with
// Op::init.
//
// Returns the thread-local partial result, or Op::init when the thread's first
// element already lies at or past `in_size`.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
inline U per_thread_all_reduce(
    const device T *in,
    const device size_t& in_size,
    uint gid,
    uint grid_size) {
  Op op;
  U total_val = Op::init;

  // Element offsets are computed in size_t so the products cannot wrap in
  // 32 bits; this matches the width already used for curr_idx below.
  const size_t thread_offset = (size_t)gid * N_READS;
  const size_t round_stride = (size_t)grid_size * N_READS;

  if (thread_offset < in_size) {
    in += thread_offset;

    // Loop-invariant: number of rounds that need no bounds check.
    const int full_rounds = (int)ceildiv(in_size, round_stride) - 1;

    int r = 0;
    for(; r < full_rounds; r++) {
      U vals[N_READS];

      // Loads are kept separate from the accumulation, as in the original.
      for(int i = 0; i < N_READS; i++) {
        vals[i] = static_cast<U>(in[i]);
      }
      for(int i = 0; i < N_READS; i++) {
        total_val = op(vals[i], total_val);
      }

      in += round_stride;
    }

    // Separate case for the last set as we close the reduction size
    size_t curr_idx = thread_offset + r * round_stride;
    if (curr_idx < in_size) {
      int max_reads = in_size - curr_idx;
      T vals[N_READS];

      // Clamp the read index so slots past the end re-read the last valid
      // element instead of touching memory beyond `in_size`.
      for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
        idx = idx < max_reads ? idx : max_reads - 1;
        vals[i] = in[idx];
      }
      // The clamped duplicates are replaced by the identity before reducing.
      for(int i = 0; i < N_READS; i++) {
        U val = i < max_reads ? static_cast<U>(vals[i]) : Op::init;
        total_val = op(val, total_val);
      }
    }
  }

  return total_val;
}
|  |  | ||||||
|  | // NB: This kernel assumes threads_per_threadgroup is at most | ||||||
|  | // 1024. This way with a simd_size of 32, we are guaranteed to | ||||||
|  | // complete the reduction in two steps of simd-level reductions. | ||||||
| template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | ||||||
| [[kernel]] void all_reduce( | [[kernel]] void all_reduce( | ||||||
|     const device T *in [[buffer(0)]], |     const device T *in [[buffer(0)]], | ||||||
| @@ -40,53 +88,18 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | |||||||
|     uint simd_per_group [[simdgroups_per_threadgroup]], |     uint simd_per_group [[simdgroups_per_threadgroup]], | ||||||
|     uint simd_lane_id [[thread_index_in_simdgroup]], |     uint simd_lane_id [[thread_index_in_simdgroup]], | ||||||
|     uint simd_group_id [[simdgroup_index_in_threadgroup]]) { |     uint simd_group_id [[simdgroup_index_in_threadgroup]]) { | ||||||
|   // NB: this kernel assumes threads_per_threadgroup is at most |  | ||||||
|   // 1024. This way with a simd_size of 32, we are guaranteed to |  | ||||||
|   // complete the reduction in two steps of simd-level reductions. |  | ||||||
|  |  | ||||||
|   Op op; |   Op op; | ||||||
|   threadgroup U local_vals[simd_size]; |   threadgroup U local_vals[simd_size]; | ||||||
|        |  | ||||||
|   U total_val = Op::init; |  | ||||||
|  |  | ||||||
|   in += gid * N_READS; |   U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size); | ||||||
|        |  | ||||||
|   int r = 0; |  | ||||||
|   for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) { |  | ||||||
|     U vals[N_READS] = {op.init};  |  | ||||||
|  |  | ||||||
|     for(int i = 0; i < N_READS; i++) { |  | ||||||
|       vals[i] = static_cast<U>(in[i]); |  | ||||||
|     } |  | ||||||
|     for(int i = 0; i < N_READS; i++) { |  | ||||||
|       total_val = op(vals[i], total_val); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     in += grid_size * N_READS; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   // Separate case for the last set as we close the reduction size  |  | ||||||
|   size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS; |  | ||||||
|   if (curr_idx < in_size) { |  | ||||||
|     int max_reads = in_size - curr_idx; |  | ||||||
|     T vals[N_READS]; |  | ||||||
|  |  | ||||||
|     for(int i = 0, idx = 0; i < N_READS; i++, idx++) { |  | ||||||
|       idx = idx < max_reads ? idx : max_reads - 1; |  | ||||||
|       vals[i] = in[idx]; |  | ||||||
|     } |  | ||||||
|     for(int i = 0; i < N_READS; i++) { |  | ||||||
|       U val = i < max_reads ? vals[i] : Op::init;  |  | ||||||
|       total_val = op(static_cast<U>(val), total_val); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   // Reduction within simd group |   // Reduction within simd group | ||||||
|   total_val = op.simd_reduce(total_val); |   total_val = op.simd_reduce(total_val); | ||||||
|   if (simd_lane_id == 0) { |   if (simd_lane_id == 0) { | ||||||
|     local_vals[simd_group_id] = total_val; |     local_vals[simd_group_id] = total_val; | ||||||
|   } |   } | ||||||
|    |  | ||||||
|   // Reduction within thread group |   // Reduction within thread group | ||||||
|   threadgroup_barrier(mem_flags::mem_threadgroup); |   threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|   total_val = lid < simd_per_group ? local_vals[lid] : op.init; |   total_val = lid < simd_per_group ? local_vals[lid] : op.init; | ||||||
| @@ -98,6 +111,46 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
// Flat reduction kernel that writes one partial result per threadgroup
// instead of combining results with atomics.
//
// Each thread first reduces its own slice via per_thread_all_reduce. The
// partials are then combined inside each SIMD group and across the SIMD
// groups of the threadgroup using simd_shuffle_down, and the thread with
// lid == 0 stores the result at out[thread_group_id].
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint thread_group_id [[threadgroup_position_in_grid]]) {

  Op op;
  threadgroup U local_vals[simd_size];

  U acc = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

  // Stage 1: shuffle-based reduction inside the SIMD group
  // (simd_add isn't supported for uint64/int64 types).
  for (uint16_t delta = simd_size / 2; delta > 0; delta /= 2) {
    acc = op(acc, simd_shuffle_down(acc, delta));
  }

  // Lane 0 of every SIMD group publishes its result to threadgroup memory.
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = acc;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Stage 2: reduce the per-SIMD-group results; threads without a
  // corresponding entry contribute the identity.
  if (lid < simd_per_group) {
    acc = local_vals[lid];
  } else {
    acc = op.init;
  }
  for (uint16_t delta = simd_size / 2; delta > 0; delta /= 2) {
    acc = op(acc, simd_shuffle_down(acc, delta));
  }

  // One partial result per threadgroup.
  if (lid == 0) {
    out[thread_group_id] = acc;
  }
}
|  |  | ||||||
| #define instantiate_all_reduce(name, itype, otype, op) \ | #define instantiate_all_reduce(name, itype, otype, op) \ | ||||||
|   template [[host_name("all_reduce_" #name)]] \ |   template [[host_name("all_reduce_" #name)]] \ | ||||||
|   [[kernel]] void all_reduce<itype, otype, op>( \ |   [[kernel]] void all_reduce<itype, otype, op>( \ | ||||||
| @@ -111,11 +164,80 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | |||||||
|       uint simd_lane_id [[thread_index_in_simdgroup]], \ |       uint simd_lane_id [[thread_index_in_simdgroup]], \ | ||||||
|       uint simd_group_id [[simdgroup_index_in_threadgroup]]); |       uint simd_group_id [[simdgroup_index_in_threadgroup]]); | ||||||
|  |  | ||||||
// Explicitly instantiates all_reduce_no_atomics<itype, otype, op> and exposes
// it to the host under the name "all_reduce_no_atomics_<name>".
#define instantiate_all_reduce_no_atomics(name, itype, otype, op)    \
  template [[host_name("all_reduce_no_atomics_" #name)]]             \
  [[kernel]] void all_reduce_no_atomics<itype, otype, op>(           \
      const device itype *in [[buffer(0)]],                          \
      device otype *out [[buffer(1)]],                               \
      const device size_t& in_size [[buffer(2)]],                    \
      uint gid [[thread_position_in_grid]],                          \
      uint lid [[thread_position_in_threadgroup]],                   \
      uint grid_size [[threads_per_grid]],                           \
      uint simd_per_group [[simdgroups_per_threadgroup]],            \
      uint simd_lane_id [[thread_index_in_simdgroup]],               \
      uint simd_group_id [[simdgroup_index_in_threadgroup]],         \
      uint thread_group_id [[threadgroup_position_in_grid]]);
|  |  | ||||||
| /////////////////////////////////////////////////////////////////////////////// | /////////////////////////////////////////////////////////////////////////////// | ||||||
| // Row atomics | // Row atomics | ||||||
| /////////////////////////////////////////////////////////////////////////////// | /////////////////////////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
// Accumulates one thread's share of a single row reduction.
//
// The row is selected by `tid` (row index = tid.y * out_size + tid.x) and
// located in `in` through elem_to_loc(shape, strides, ndim). Within the row,
// thread `lid_x` starts at element lid_x * N_READS and advances by
// lsize_x * N_READS elements per round. Every round except the last is read
// without bounds checks; the last round is checked against `reduction_size`
// and short reads are padded with Op::init.
//
// Returns the thread-local partial result (Op::init if the thread has no
// elements to read).
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
inline U per_thread_row_reduce(
    const device T *in,
    const constant size_t& reduction_size,
    const constant size_t& out_size,
    const constant int* shape,
    const constant size_t* strides,
    const constant int& ndim,
    uint lsize_x,
    uint lid_x,
    uint2 tid) {

  Op op;

  // Each threadgroup handles 1 reduction
  // TODO: Specializing elem_to_loc would be slightly faster
  // The row index and the resulting offset keep their natural width instead
  // of being narrowed to int, which would truncate them for large inputs.
  const auto row = tid.y * out_size + tid.x;
  const auto row_offset = elem_to_loc(row, shape, strides, ndim);
  in += row_offset + lid_x * N_READS;

  // The reduction is accumulated here
  U total_val = Op::init;

  // Loop-invariant: number of rounds that need no bounds check.
  const int full_rounds = (int)ceildiv(reduction_size, N_READS*lsize_x) - 1;

  // Loop over the reduction size within thread group
  int r = 0;
  for (; r < full_rounds; r++) {
    T vals[N_READS];
    for(int i = 0; i < N_READS; i++) {
      vals[i] = in[i];
    }
    for(int i = 0; i < N_READS; i++) {
      total_val = op(static_cast<U>(vals[i]), total_val);
    }

    in += lsize_x * N_READS;
  }

  // Separate case for the last set as we close the reduction size
  size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
  if(reduction_index < reduction_size) {
    int max_reads = reduction_size - reduction_index;

    // Clamp the read index so slots past the end re-read the last valid
    // element instead of reading beyond the row.
    T vals[N_READS];
    for(int i = 0; i < N_READS; i++) {
      int idx = min(i, max_reads - 1);
      vals[i] = in[idx];
    }
    // Padding is done in the accumulator type U: routing Op::init through a
    // T temporary would truncate the identity whenever U is wider than T.
    for(int i = 0; i < N_READS; i++) {
      U val = i < max_reads ? static_cast<U>(vals[i]) : Op::init;
      total_val = op(val, total_val);
    }
  }

  return total_val;
}
|  |  | ||||||
| template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | ||||||
| [[kernel]] void row_reduce_general( | [[kernel]] void row_reduce_general( | ||||||
|     const device T *in [[buffer(0)]], |     const device T *in [[buffer(0)]], | ||||||
| @@ -133,46 +255,9 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | |||||||
|     uint simd_group_id [[simdgroup_index_in_threadgroup]]) { |     uint simd_group_id [[simdgroup_index_in_threadgroup]]) { | ||||||
|  |  | ||||||
|   Op op; |   Op op; | ||||||
|  |  | ||||||
|   // Each threadgroup handles 1 reduction |  | ||||||
|   // TODO: Specializing elem_to_loc would be slightly faster |  | ||||||
|   int idx = tid.y * out_size + tid.x; |  | ||||||
|   int extra_offset = elem_to_loc(idx, shape, strides, ndim); |  | ||||||
|   in += extra_offset + lid.x * N_READS; |  | ||||||
|    |  | ||||||
|   // The reduction is accumulated here |  | ||||||
|   U total_val = Op::init; |  | ||||||
|   threadgroup U local_vals[simd_size]; |   threadgroup U local_vals[simd_size]; | ||||||
|  |  | ||||||
|   // Loop over the reduction size within thread group |   U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy); | ||||||
|   int r = 0; |  | ||||||
|   for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) { |  | ||||||
|     T vals[N_READS];  |  | ||||||
|     for(int i = 0; i < N_READS; i++) { |  | ||||||
|       vals[i] = in[i]; |  | ||||||
|     } |  | ||||||
|     for(int i = 0; i < N_READS; i++) { |  | ||||||
|       total_val = op(static_cast<U>(vals[i]), total_val); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     in += lsize.x * N_READS; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   // Separate case for the last set as we close the reduction size    |  | ||||||
|   size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS; |  | ||||||
|   if(reduction_index < reduction_size) { |  | ||||||
|     int max_reads = reduction_size - reduction_index; |  | ||||||
|  |  | ||||||
|     T vals[N_READS];  |  | ||||||
|     for(int i = 0; i < N_READS; i++) { |  | ||||||
|       int idx = min(i, max_reads - 1); |  | ||||||
|       vals[i] = static_cast<U>(in[idx]); |  | ||||||
|     } |  | ||||||
|     for(int i = 0; i < N_READS; i++) { |  | ||||||
|       T val = i < max_reads ? vals[i] : Op::init; |  | ||||||
|       total_val = op(static_cast<U>(val), total_val); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   total_val = op.simd_reduce(total_val); |   total_val = op.simd_reduce(total_val); | ||||||
|    |    | ||||||
| @@ -194,6 +279,53 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
// Row reduction kernel that stores its result with a plain write instead of
// an atomic update.
//
// Each thread reduces its slice of the row via per_thread_row_reduce; the
// partials are combined inside each SIMD group and, when the row needs more
// than one SIMD group's worth of threads, across SIMD groups. The thread with
// lid.x == 0 writes the result to
// out[ceildiv(gsize.y, lsize.y) * tid.x + tid.y].
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {

  Op op;

  threadgroup U local_vals[simd_size];
  U acc = per_thread_row_reduce<T, U, Op, N_READS>(
      in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);

  // Stage 1: shuffle-based reduction inside the SIMD group
  // (simd_add isn't supported for int64 types).
  for (uint16_t delta = simd_size / 2; delta > 0; delta /= 2) {
    acc = op(acc, simd_shuffle_down(acc, delta));
  }

  // Lane 0 of every SIMD group publishes its result for the next stage.
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = acc;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Stage 2: combine the per-SIMD-group results. Only needed if the thread
  // group has multiple simd groups.
  if (ceildiv(reduction_size, N_READS) > simd_size) {
    if (lid.x < simd_per_group) {
      acc = local_vals[lid.x];
    } else {
      acc = op.init;
    }
    for (uint16_t delta = simd_size / 2; delta > 0; delta /= 2) {
      acc = op(acc, simd_shuffle_down(acc, delta));
    }
  }

  // The first thread of the threadgroup writes the row's result.
  if (lid.x == 0) {
    out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = acc;
  }
}
|  |  | ||||||
| #define instantiate_row_reduce_general(name, itype, otype, op) \ | #define instantiate_row_reduce_general(name, itype, otype, op) \ | ||||||
|   template [[host_name("row_reduce_general_" #name)]] \ |   template [[host_name("row_reduce_general_" #name)]] \ | ||||||
|   [[kernel]] void row_reduce_general<itype, otype, op>( \ |   [[kernel]] void row_reduce_general<itype, otype, op>( \ | ||||||
| @@ -211,52 +343,59 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS> | |||||||
|       uint simd_per_group [[simdgroups_per_threadgroup]],  \ |       uint simd_per_group [[simdgroups_per_threadgroup]],  \ | ||||||
|       uint simd_group_id [[simdgroup_index_in_threadgroup]]); |       uint simd_group_id [[simdgroup_index_in_threadgroup]]); | ||||||
|  |  | ||||||
// Explicitly instantiates row_reduce_general_no_atomics<itype, otype, op> and
// exposes it to the host as "row_reduce_general_no_atomics_<name>".
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op)  \
  template [[host_name("row_reduce_general_no_atomics_" #name)]]           \
  [[kernel]] void row_reduce_general_no_atomics<itype, otype, op>(         \
      const device itype *in [[buffer(0)]],                                \
      device otype *out [[buffer(1)]],                                     \
      const constant size_t& reduction_size [[buffer(2)]],                 \
      const constant size_t& out_size [[buffer(3)]],                       \
      const constant int* shape [[buffer(4)]],                             \
      const constant size_t* strides [[buffer(5)]],                        \
      const constant int& ndim [[buffer(6)]],                              \
      uint3 lid [[thread_position_in_threadgroup]],                        \
      uint3 lsize [[threads_per_threadgroup]],                             \
      uint3 gsize [[threads_per_grid]],                                    \
      uint3 tid [[threadgroup_position_in_grid]],                          \
      uint simd_lane_id [[thread_index_in_simdgroup]],                     \
      uint simd_per_group [[simdgroups_per_threadgroup]],                  \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|  |  | ||||||
| /////////////////////////////////////////////////////////////////////////////// | /////////////////////////////////////////////////////////////////////////////// | ||||||
| // Column reduce | // Column reduce | ||||||
| /////////////////////////////////////////////////////////////////////////////// | /////////////////////////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
// Reduces a strided column segment within one threadgroup and returns the
// result instead of writing it out.
//
// Each thread reduces up to N_READS elements starting at reduction position
// (tid.y * lsize.y + lid.y) * N_READS, reading
// in[in_idx + position * reduction_stride], and stores its partial in
// local_data[lsize.y * lid.x + lid.y]. After a threadgroup barrier, threads
// with lid.y == 0 combine the lsize.y partials that share their lid.x.
//
// `tid`, `lid` and `lsize` are the x/y components of the threadgroup
// position, the thread position in the threadgroup and the threadgroup size.
//
// Returns the combined value for threads with lid.y == 0 and Op::init for all
// other threads.
//
// NOTE: this function executes a threadgroup_barrier, so it has to be reached
// by every thread of the threadgroup.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline U _contiguous_strided_reduce(
    const device T *in,
    threadgroup U *local_data,
    uint in_idx,
    uint reduction_size,
    uint reduction_stride,
    uint2 tid,
    uint2 lid,
    uint2 lsize) {

  Op op;
  U total_val = Op::init;

  uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
  for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
    uint offset = base_offset + r;
    // The element index is formed in size_t: offset * reduction_stride can
    // exceed 32 bits even when both factors fit in a uint.
    const size_t elem = in_idx + (size_t)offset * reduction_stride;
    // The loaded T element is the operand that needs converting to U.
    total_val = op(total_val, static_cast<U>(in[elem]));
  }
  local_data[lsize.y * lid.x + lid.y] = total_val;

  threadgroup_barrier(mem_flags::mem_threadgroup);

  U val = Op::init;
  if(lid.y == 0) {
    // Perform reduction across columns in thread group
    for(uint i = 0; i < lsize.y; i++) {
      val = op(val, local_data[lsize.y * lid.x + i]);
    }
  }

  return val;
}
|  |  | ||||||
| template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS> | template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS> | ||||||
| @@ -265,13 +404,13 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS> | |||||||
|     device mlx_atomic<U> *out [[buffer(1)]], |     device mlx_atomic<U> *out [[buffer(1)]], | ||||||
|     const constant size_t& reduction_size [[buffer(2)]], |     const constant size_t& reduction_size [[buffer(2)]], | ||||||
|     const constant size_t& reduction_stride [[buffer(3)]], |     const constant size_t& reduction_stride [[buffer(3)]], | ||||||
|     const constant size_t& out_size [[buffer(4)]],  |     const constant size_t& out_size [[buffer(4)]], | ||||||
|     const constant int* shape [[buffer(5)]], |     const constant int* shape [[buffer(5)]], | ||||||
|     const constant size_t* strides [[buffer(6)]], |     const constant size_t* strides [[buffer(6)]], | ||||||
|     const constant int& ndim [[buffer(7)]], |     const constant int& ndim [[buffer(7)]], | ||||||
|     threadgroup U *local_data [[threadgroup(0)]], |     threadgroup U *local_data [[threadgroup(0)]], | ||||||
|     uint3 tid [[threadgroup_position_in_grid]],  |     uint3 tid [[threadgroup_position_in_grid]], | ||||||
|     uint3 lid [[thread_position_in_threadgroup]],  |     uint3 lid [[thread_position_in_threadgroup]], | ||||||
|     uint3 lsize [[threads_per_threadgroup]]) { |     uint3 lsize [[threads_per_threadgroup]]) { | ||||||
|   auto out_idx = tid.x * lsize.x + lid.x; |   auto out_idx = tid.x * lsize.x + lid.x; | ||||||
|   auto in_idx = elem_to_loc( |   auto in_idx = elem_to_loc( | ||||||
| @@ -281,18 +420,66 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS> | |||||||
|     ndim |     ndim | ||||||
|   ); |   ); | ||||||
|  |  | ||||||
|  |   Op op; | ||||||
|   if(out_idx < out_size) { |   if(out_idx < out_size) { | ||||||
|     _contiguous_strided_reduce<T, U, Op, N_READS>( |     U val = _contiguous_strided_reduce<T, U, Op, N_READS>( | ||||||
|       in,  |               in, | ||||||
|       out,  |               local_data, | ||||||
|       local_data,  |               in_idx, | ||||||
|       in_idx, |               reduction_size, | ||||||
|       out_idx,  |               reduction_stride, | ||||||
|       reduction_size,  |               tid.xy, | ||||||
|       reduction_stride,  |               lid.xy, | ||||||
|       tid.xy,  |               lsize.xy); | ||||||
|       lid.xy,  |  | ||||||
|       lsize.xy);  |     // Write out reduction results generated by threadgroups working on specific output element, contiguously. | ||||||
|  |     if (lid.y == 0) { | ||||||
|  |       op.atomic_update(out, val, out_idx); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
// Column reduction kernel that stores per-threadgroup partial results with
// plain writes instead of atomic updates.
//
// Each thread maps to output element out_idx = tid.x * lsize.x + lid.x and
// reduces a segment of that column via _contiguous_strided_reduce. Threads
// with lid.y == 0 then write the threadgroup's partial to
// out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y], where
// tgsize_y / tgsize_z are the number of threadgroups along y / z.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 gid [[thread_position_in_grid]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]]) {
  auto out_idx = tid.x * lsize.x + lid.x;
  auto in_idx = elem_to_loc(
    out_idx + tid.z * out_size,
    shape,
    strides,
    ndim
  );

  const bool in_bounds = out_idx < out_size;

  // _contiguous_strided_reduce executes a threadgroup_barrier, so it must be
  // reached by every thread of the threadgroup, including threads whose
  // out_idx is past out_size. Those threads pass a reduction length of 0, so
  // they load nothing from `in` and only store Op::init into their own
  // local_data slot.
  // NOTE(review): assumes local_data holds lsize.x * lsize.y elements —
  // confirm against the host-side threadgroup memory allocation.
  const size_t thread_reduction_size = in_bounds ? reduction_size : 0;
  U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
            in,
            local_data,
            in_idx,
            thread_reduction_size,
            reduction_stride,
            tid.xy,
            lid.xy,
            lsize.xy);

  // Write out reduction results generated by threadgroups working on specific output element, contiguously.
  if (in_bounds && lid.y == 0) {
    uint tgsize_y = ceildiv(gsize.y, lsize.y);
    uint tgsize_z = ceildiv(gsize.z, lsize.z);
    out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
  }
}
|  |  | ||||||
| @@ -312,6 +499,23 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS> | |||||||
|       uint3 lid [[thread_position_in_threadgroup]], \ |       uint3 lid [[thread_position_in_threadgroup]], \ | ||||||
|       uint3 lsize [[threads_per_threadgroup]]); |       uint3 lsize [[threads_per_threadgroup]]); | ||||||
|  |  | ||||||
// Explicitly instantiates col_reduce_general_no_atomics<itype, otype, op> and
// exposes it to the host as "col_reduce_general_no_atomics_<name>".
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op)  \
  template [[host_name("col_reduce_general_no_atomics_" #name)]]           \
  [[kernel]] void col_reduce_general_no_atomics<itype, otype, op>(         \
      const device itype *in [[buffer(0)]],                                \
      device otype *out [[buffer(1)]],                                     \
      const constant size_t& reduction_size [[buffer(2)]],                 \
      const constant size_t& reduction_stride [[buffer(3)]],               \
      const constant size_t& out_size [[buffer(4)]],                       \
      const constant int* shape [[buffer(5)]],                             \
      const constant size_t* strides [[buffer(6)]],                        \
      const constant int& ndim [[buffer(7)]],                              \
      threadgroup otype *local_data [[threadgroup(0)]],                    \
      uint3 tid [[threadgroup_position_in_grid]],                          \
      uint3 lid [[thread_position_in_threadgroup]],                        \
      uint3 gid [[thread_position_in_grid]],                               \
      uint3 lsize [[threads_per_threadgroup]],                             \
      uint3 gsize [[threads_per_grid]]);
|  |  | ||||||
| /////////////////////////////////////////////////////////////////////////////// | /////////////////////////////////////////////////////////////////////////////// | ||||||
| // Instantiations | // Instantiations | ||||||
| @@ -322,6 +526,15 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS> | |||||||
|   instantiate_row_reduce_general(name, itype, otype, op) \ |   instantiate_row_reduce_general(name, itype, otype, op) \ | ||||||
|   instantiate_col_reduce_general(name, itype, otype, op) |   instantiate_col_reduce_general(name, itype, otype, op) | ||||||
|  |  | ||||||
// Instantiates the all / row / column "no atomics" kernels for one
// (input type, output type, op) combination.
#define instantiate_reduce_no_atomics(name, itype, otype, op)          \
  instantiate_all_reduce_no_atomics(name, itype, otype, op)            \
  instantiate_row_reduce_general_no_atomics(name, itype, otype, op)    \
  instantiate_col_reduce_general_no_atomics(name, itype, otype, op)

// Same-type variant: input and output share `type`. Expands to the init
// kernel plus the "no atomics" reduction kernels, named <name><tname>.
#define instantiate_same_reduce_no_atomics(name, tname, type, op)      \
  instantiate_init_reduce(name ##tname, type, op<type>)                \
  instantiate_reduce_no_atomics(name ##tname, type, type, op<type>)
|  |  | ||||||
| #define instantiate_same_reduce(name, tname, type, op) \ | #define instantiate_same_reduce(name, tname, type, op) \ | ||||||
|   instantiate_init_reduce(name ##tname, type, op<type>) \ |   instantiate_init_reduce(name ##tname, type, op<type>) \ | ||||||
|   instantiate_reduce(name ##tname, type, type, op<type>) |   instantiate_reduce(name ##tname, type, type, op<type>) | ||||||
| @@ -353,6 +566,9 @@ instantiate_same_reduce(sum, int32, int32_t, Sum) | |||||||
| instantiate_same_reduce(sum, float16, half, Sum) | instantiate_same_reduce(sum, float16, half, Sum) | ||||||
| instantiate_same_reduce(sum, float32, float, Sum) | instantiate_same_reduce(sum, float32, float, Sum) | ||||||
|  |  | ||||||
|  | instantiate_same_reduce_no_atomics(sum, int64, int64_t, Sum) | ||||||
|  | instantiate_same_reduce_no_atomics(sum, uint64, uint64_t, Sum) | ||||||
|  |  | ||||||
| instantiate_same_reduce(prod, uint8, uint8_t, Prod) | instantiate_same_reduce(prod, uint8, uint8_t, Prod) | ||||||
| instantiate_same_reduce(prod, uint16, uint16_t, Prod) | instantiate_same_reduce(prod, uint16, uint16_t, Prod) | ||||||
| instantiate_same_reduce(prod, uint32, uint32_t, Prod) | instantiate_same_reduce(prod, uint32, uint32_t, Prod) | ||||||
| @@ -362,6 +578,9 @@ instantiate_same_reduce(prod, int32, int32_t, Prod) | |||||||
| instantiate_same_reduce(prod, float16, half, Prod) | instantiate_same_reduce(prod, float16, half, Prod) | ||||||
| instantiate_same_reduce(prod, float32, float, Prod) | instantiate_same_reduce(prod, float32, float, Prod) | ||||||
|  |  | ||||||
|  | instantiate_same_reduce_no_atomics(prod, int64, int64_t, Prod) | ||||||
|  | instantiate_same_reduce_no_atomics(prod, uint64, uint64_t, Prod) | ||||||
|  |  | ||||||
| instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum) | instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum) | ||||||
| instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod) | instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod) | ||||||
|  |  | ||||||
| @@ -381,6 +600,9 @@ instantiate_same_reduce(min_, int32, int32_t, Min) | |||||||
| instantiate_same_reduce(min_, float16, half, Min) | instantiate_same_reduce(min_, float16, half, Min) | ||||||
| instantiate_same_reduce(min_, float32, float, Min) | instantiate_same_reduce(min_, float32, float, Min) | ||||||
|  |  | ||||||
|  | instantiate_same_reduce_no_atomics(min_, int64, int64_t, Min) | ||||||
|  | instantiate_same_reduce_no_atomics(min_, uint64, uint64_t, Min) | ||||||
|  |  | ||||||
| instantiate_same_reduce(max_, uint8, uint8_t, Max) | instantiate_same_reduce(max_, uint8, uint8_t, Max) | ||||||
| instantiate_same_reduce(max_, uint16, uint16_t, Max) | instantiate_same_reduce(max_, uint16, uint16_t, Max) | ||||||
| instantiate_same_reduce(max_, uint32, uint32_t, Max) | instantiate_same_reduce(max_, uint32, uint32_t, Max) | ||||||
| @@ -390,5 +612,8 @@ instantiate_same_reduce(max_, int32, int32_t, Max) | |||||||
| instantiate_same_reduce(max_, float16, half, Max) | instantiate_same_reduce(max_, float16, half, Max) | ||||||
| instantiate_same_reduce(max_, float32, float, Max) | instantiate_same_reduce(max_, float32, float, Max) | ||||||
|  |  | ||||||
|  | instantiate_same_reduce_no_atomics(max_, int64, int64_t, Max) | ||||||
|  | instantiate_same_reduce_no_atomics(max_, uint64, uint64_t, Max) | ||||||
|  |  | ||||||
| instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min) | instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min) | ||||||
| instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max) | instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max) | ||||||
|   | |||||||
| @@ -256,3 +256,21 @@ inline bfloat16_t log1p(bfloat16_t x) { | |||||||
|  |  | ||||||
|   return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); |   return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | /////////////////////////////////////////////////////////////////////////////// | ||||||
|  | // SIMD shuffle ops | ||||||
|  | /////////////////////////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
|  | inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) { | ||||||
|  |   return as_type<uint64_t>( | ||||||
|  |       metal::simd_shuffle_down(as_type<uint2>(data), delta)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) { | ||||||
|  |   return as_type<int64_t>( | ||||||
|  |       metal::simd_shuffle_down(as_type<uint2>(data), delta)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline bool simd_shuffle_down(bool data, uint16_t delta) { | ||||||
|  |   return simd_shuffle_down(static_cast<uint32_t>(data), delta); | ||||||
|  | } | ||||||
| @@ -28,35 +28,40 @@ inline auto safe_divup(size_t n, size_t m) { | |||||||
|   return safe_div(n, m) * m; |   return safe_div(n, m) * m; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | inline bool is_64b_int(Dtype dtype) { | ||||||
|  |   return dtype == int64 || dtype == uint64; | ||||||
|  | } | ||||||
|  |  | ||||||
| // All Reduce | // All Reduce | ||||||
| void all_reduce_dispatch( | void all_reduce_dispatch( | ||||||
|     const array& in, |     const array& in, | ||||||
|     array& out, |     array& out, | ||||||
|     const std::string& op_name, |     const std::string& op_name, | ||||||
|     MTL::ComputeCommandEncoder* compute_encoder, |     MTL::ComputeCommandEncoder* compute_encoder, | ||||||
|     metal::Device& d) { |     metal::Device& d, | ||||||
|   // Get kernel and encode buffers |     const Stream& s) { | ||||||
|   size_t in_size = in.size(); |   Dtype out_dtype = out.dtype(); | ||||||
|   auto kernel = d.get_kernel("all_reduce_" + op_name + type_to_name(in)); |   bool is_out_64b_int = is_64b_int(out_dtype); | ||||||
|  |   auto kernel = (is_out_64b_int) | ||||||
|  |       ? d.get_kernel("all_reduce_no_atomics_" + op_name + type_to_name(in)) | ||||||
|  |       : d.get_kernel("all_reduce_" + op_name + type_to_name(in)); | ||||||
|  |  | ||||||
|   compute_encoder->setComputePipelineState(kernel); |   compute_encoder->setComputePipelineState(kernel); | ||||||
|   set_array_buffer(compute_encoder, in, 0); |  | ||||||
|   set_array_buffer(compute_encoder, out, 1); |  | ||||||
|   compute_encoder->setBytes(&in_size, sizeof(size_t), 2); |  | ||||||
|  |  | ||||||
|   // Set grid dimensions |  | ||||||
|  |  | ||||||
|   // We make sure each thread has enough to do by making it read in |   // We make sure each thread has enough to do by making it read in | ||||||
|   // at least n_reads inputs |   // at least n_reads inputs | ||||||
|   int n_reads = REDUCE_N_READS; |   int n_reads = REDUCE_N_READS; | ||||||
|  |   size_t in_size = in.size(); | ||||||
|  |  | ||||||
|   // mod_in_size gives us the groups of n_reads needed to go over the entire |   // mod_in_size gives us the groups of n_reads needed to go over the entire | ||||||
|   // input |   // input | ||||||
|   uint mod_in_size = (in_size + n_reads - 1) / n_reads; |   uint mod_in_size = (in_size + n_reads - 1) / n_reads; | ||||||
|  |  | ||||||
|   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); |   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); | ||||||
|   thread_group_size = |   thread_group_size = | ||||||
|       mod_in_size > thread_group_size ? thread_group_size : mod_in_size; |       mod_in_size > thread_group_size ? thread_group_size : mod_in_size; | ||||||
|  |   uint simd_size = kernel->threadExecutionWidth(); | ||||||
|  |   thread_group_size = | ||||||
|  |       ((thread_group_size + simd_size - 1) / simd_size) * simd_size; | ||||||
|  |  | ||||||
|   // If the number of thread groups needed exceeds 1024, we reuse threads groups |   // If the number of thread groups needed exceeds 1024, we reuse threads groups | ||||||
|   uint n_thread_groups = safe_div(mod_in_size, thread_group_size); |   uint n_thread_groups = safe_div(mod_in_size, thread_group_size); | ||||||
| @@ -66,7 +71,52 @@ void all_reduce_dispatch( | |||||||
|   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); |   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); | ||||||
|   MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); |   MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); | ||||||
|  |  | ||||||
|   compute_encoder->dispatchThreads(grid_dims, group_dims); |   // Encode buffers and dispatch | ||||||
|  |   if (is_out_64b_int == false || n_thread_groups == 1) { | ||||||
|  |     set_array_buffer(compute_encoder, in, 0); | ||||||
|  |     set_array_buffer(compute_encoder, out, 1); | ||||||
|  |     compute_encoder->setBytes(&in_size, sizeof(size_t), 2); | ||||||
|  |     compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||||
|  |  | ||||||
|  |   } else { | ||||||
|  |     // Allocate intermediate array to store partial reduction results | ||||||
|  |     size_t intermediate_size = n_thread_groups; | ||||||
|  |     array intermediate = | ||||||
|  |         array({static_cast<int>(intermediate_size)}, out_dtype, nullptr, {}); | ||||||
|  |     intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); | ||||||
|  |     std::vector<array> intermediates = {intermediate}; | ||||||
|  |  | ||||||
|  |     // First dispatch | ||||||
|  |     set_array_buffer(compute_encoder, in, 0); | ||||||
|  |     set_array_buffer(compute_encoder, intermediate, 1); | ||||||
|  |     compute_encoder->setBytes(&in_size, sizeof(size_t), 2); | ||||||
|  |     compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||||
|  |  | ||||||
|  |     // Second pass to reduce intermediate reduction results written to DRAM | ||||||
|  |     set_array_buffer(compute_encoder, intermediate, 0); | ||||||
|  |     set_array_buffer(compute_encoder, out, 1); | ||||||
|  |     compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2); | ||||||
|  |  | ||||||
|  |     mod_in_size = (intermediate_size + n_reads - 1) / n_reads; | ||||||
|  |  | ||||||
|  |     thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); | ||||||
|  |     thread_group_size = | ||||||
|  |         mod_in_size > thread_group_size ? thread_group_size : mod_in_size; | ||||||
|  |     thread_group_size = | ||||||
|  |         ((thread_group_size + simd_size - 1) / simd_size) * simd_size; | ||||||
|  |  | ||||||
|  |     // If the number of thread groups needed exceeds 1024, we reuse threads | ||||||
|  |     // groups | ||||||
|  |     nthreads = thread_group_size; | ||||||
|  |     group_dims = MTL::Size(thread_group_size, 1, 1); | ||||||
|  |     grid_dims = MTL::Size(nthreads, 1, 1); | ||||||
|  |     compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||||
|  |  | ||||||
|  |     d.get_command_buffer(s.index)->addCompletedHandler( | ||||||
|  |         [intermediates](MTL::CommandBuffer*) mutable { | ||||||
|  |           intermediates.clear(); | ||||||
|  |         }); | ||||||
|  |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| void row_reduce_general_dispatch( | void row_reduce_general_dispatch( | ||||||
| @@ -76,22 +126,31 @@ void row_reduce_general_dispatch( | |||||||
|     const ReductionPlan& plan, |     const ReductionPlan& plan, | ||||||
|     const std::vector<int>& axes, |     const std::vector<int>& axes, | ||||||
|     MTL::ComputeCommandEncoder* compute_encoder, |     MTL::ComputeCommandEncoder* compute_encoder, | ||||||
|     metal::Device& d) { |     metal::Device& d, | ||||||
|   auto kernel = |     const Stream& s) { | ||||||
|       d.get_kernel("row_reduce_general_" + op_name + type_to_name(in)); |   Dtype out_dtype = out.dtype(); | ||||||
|  |   bool is_out_64b_int = is_64b_int(out_dtype); | ||||||
|  |   auto kernel = (is_out_64b_int) | ||||||
|  |       ? d.get_kernel( | ||||||
|  |             "row_reduce_general_no_atomics_" + op_name + type_to_name(in)) | ||||||
|  |       : d.get_kernel("row_reduce_general_" + op_name + type_to_name(in)); | ||||||
|  |  | ||||||
|  |   compute_encoder->setComputePipelineState(kernel); | ||||||
|  |  | ||||||
|   // Prepare the arguments for the kernel |   // Prepare the arguments for the kernel | ||||||
|   int n_reads = REDUCE_N_READS; |   int n_reads = REDUCE_N_READS; | ||||||
|   size_t reduction_size = plan.shape.back(); |   size_t reduction_size = plan.shape.back(); | ||||||
|   size_t out_size = out.size(); |  | ||||||
|   auto shape = plan.shape; |   auto shape = plan.shape; | ||||||
|   auto strides = plan.strides; |   auto strides = plan.strides; | ||||||
|  |  | ||||||
|   shape.pop_back(); |   shape.pop_back(); | ||||||
|   strides.pop_back(); |   strides.pop_back(); | ||||||
|  |  | ||||||
|   size_t non_row_reductions = 1; |   size_t non_row_reductions = 1; | ||||||
|   for (auto s : shape) { |   for (auto s : shape) { | ||||||
|     non_row_reductions *= static_cast<size_t>(s); |     non_row_reductions *= static_cast<size_t>(s); | ||||||
|   } |   } | ||||||
|  |   size_t out_size = out.size(); | ||||||
|   auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes); |   auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes); | ||||||
|   for (auto s : rem_shape) { |   for (auto s : rem_shape) { | ||||||
|     shape.push_back(s); |     shape.push_back(s); | ||||||
| @@ -101,16 +160,6 @@ void row_reduce_general_dispatch( | |||||||
|   } |   } | ||||||
|   int ndim = shape.size(); |   int ndim = shape.size(); | ||||||
|  |  | ||||||
|   // Set the arguments for the kernel |  | ||||||
|   compute_encoder->setComputePipelineState(kernel); |  | ||||||
|   set_array_buffer(compute_encoder, in, 0); |  | ||||||
|   set_array_buffer(compute_encoder, out, 1); |  | ||||||
|   compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); |  | ||||||
|   compute_encoder->setBytes(&out_size, sizeof(size_t), 3); |  | ||||||
|   compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); |  | ||||||
|   compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5); |  | ||||||
|   compute_encoder->setBytes(&ndim, sizeof(int), 6); |  | ||||||
|  |  | ||||||
|   // Each thread group is responsible for 1 output |   // Each thread group is responsible for 1 output | ||||||
|   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); |   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); | ||||||
|   thread_group_size = |   thread_group_size = | ||||||
| @@ -127,7 +176,88 @@ void row_reduce_general_dispatch( | |||||||
|   MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1); |   MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1); | ||||||
|   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); |   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); | ||||||
|  |  | ||||||
|   compute_encoder->dispatchThreads(grid_dims, group_dims); |   if (is_out_64b_int == false || non_row_reductions == 1) { | ||||||
|  |     // Set the arguments for the kernel | ||||||
|  |     set_array_buffer(compute_encoder, in, 0); | ||||||
|  |     set_array_buffer(compute_encoder, out, 1); | ||||||
|  |     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); | ||||||
|  |     compute_encoder->setBytes(&out_size, sizeof(size_t), 3); | ||||||
|  |     compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); | ||||||
|  |     compute_encoder->setBytes( | ||||||
|  |         strides.data(), strides.size() * sizeof(size_t), 5); | ||||||
|  |     compute_encoder->setBytes(&ndim, sizeof(int), 6); | ||||||
|  |     compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||||
|  |  | ||||||
|  |   } else { | ||||||
|  |     // Allocate intermediate array to store partial reduction results | ||||||
|  |     array intermediate = array( | ||||||
|  |         {static_cast<int>(out.size()), static_cast<int>(non_row_reductions)}, | ||||||
|  |         out_dtype, | ||||||
|  |         nullptr, | ||||||
|  |         {}); | ||||||
|  |     intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); | ||||||
|  |     std::vector<array> intermediates = {intermediate}; | ||||||
|  |  | ||||||
|  |     // Set the arguments for the kernel | ||||||
|  |     set_array_buffer(compute_encoder, in, 0); | ||||||
|  |     set_array_buffer(compute_encoder, intermediate, 1); | ||||||
|  |     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); | ||||||
|  |     compute_encoder->setBytes(&out_size, sizeof(size_t), 3); | ||||||
|  |     compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); | ||||||
|  |     compute_encoder->setBytes( | ||||||
|  |         strides.data(), strides.size() * sizeof(size_t), 5); | ||||||
|  |     compute_encoder->setBytes(&ndim, sizeof(int), 6); | ||||||
|  |     compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||||
|  |  | ||||||
|  |     // Set up second dispatch | ||||||
|  |     reduction_size = non_row_reductions; | ||||||
|  |     out_size = 1; | ||||||
|  |  | ||||||
|  |     // Shape of axes that aren't participating in reduction remains unchanged. | ||||||
|  |     std::vector<int> new_shape = rem_shape; | ||||||
|  |  | ||||||
|  |     // Update their strides since they'll be different post partial reduction in | ||||||
|  |     // first compute dispatch. | ||||||
|  |     std::vector<size_t> new_strides = rem_strides; | ||||||
|  |     new_strides.back() = reduction_size; | ||||||
|  |     for (int i = new_shape.size() - 2; i >= 0; i--) { | ||||||
|  |       new_strides[i] = new_shape[i + 1] * new_strides[i + 1]; | ||||||
|  |     } | ||||||
|  |     ndim = new_shape.size(); | ||||||
|  |  | ||||||
|  |     // Set the arguments for the kernel | ||||||
|  |     set_array_buffer(compute_encoder, intermediate, 0); | ||||||
|  |     set_array_buffer(compute_encoder, out, 1); | ||||||
|  |     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); | ||||||
|  |     compute_encoder->setBytes(&out_size, sizeof(size_t), 3); | ||||||
|  |     compute_encoder->setBytes( | ||||||
|  |         new_shape.data(), new_shape.size() * sizeof(int), 4); | ||||||
|  |     compute_encoder->setBytes( | ||||||
|  |         new_strides.data(), new_strides.size() * sizeof(size_t), 5); | ||||||
|  |     compute_encoder->setBytes(&ndim, sizeof(int), 6); | ||||||
|  |  | ||||||
|  |     // Each thread group is responsible for 1 output | ||||||
|  |     thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); | ||||||
|  |     thread_group_size = | ||||||
|  |         std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size); | ||||||
|  |  | ||||||
|  |     // Align thread group size with simd_size | ||||||
|  |     thread_group_size = | ||||||
|  |         (thread_group_size + simd_size - 1) / simd_size * simd_size; | ||||||
|  |     assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup()); | ||||||
|  |  | ||||||
|  |     // Launch enough thread groups for each output | ||||||
|  |     n_threads = thread_group_size; | ||||||
|  |     grid_dims = MTL::Size(n_threads, out.size(), 1); | ||||||
|  |     group_dims = MTL::Size(thread_group_size, 1, 1); | ||||||
|  |  | ||||||
|  |     compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||||
|  |  | ||||||
|  |     d.get_command_buffer(s.index)->addCompletedHandler( | ||||||
|  |         [intermediates](MTL::CommandBuffer*) mutable { | ||||||
|  |           intermediates.clear(); | ||||||
|  |         }); | ||||||
|  |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| void strided_reduce_general_dispatch( | void strided_reduce_general_dispatch( | ||||||
| @@ -137,9 +267,16 @@ void strided_reduce_general_dispatch( | |||||||
|     const ReductionPlan& plan, |     const ReductionPlan& plan, | ||||||
|     const std::vector<int>& axes, |     const std::vector<int>& axes, | ||||||
|     MTL::ComputeCommandEncoder* compute_encoder, |     MTL::ComputeCommandEncoder* compute_encoder, | ||||||
|     metal::Device& d) { |     metal::Device& d, | ||||||
|   auto kernel = |     const Stream& s) { | ||||||
|       d.get_kernel("col_reduce_general_" + op_name + type_to_name(in)); |   Dtype out_dtype = out.dtype(); | ||||||
|  |   bool is_out_64b_int = is_64b_int(out_dtype); | ||||||
|  |   auto kernel = (is_out_64b_int) | ||||||
|  |       ? d.get_kernel( | ||||||
|  |             "col_reduce_general_no_atomics_" + op_name + type_to_name(in)) | ||||||
|  |       : d.get_kernel("col_reduce_general_" + op_name + type_to_name(in)); | ||||||
|  |  | ||||||
|  |   compute_encoder->setComputePipelineState(kernel); | ||||||
|  |  | ||||||
|   // Prepare the arguments for the kernel |   // Prepare the arguments for the kernel | ||||||
|   size_t reduction_size = plan.shape.back(); |   size_t reduction_size = plan.shape.back(); | ||||||
| @@ -162,19 +299,7 @@ void strided_reduce_general_dispatch( | |||||||
|   } |   } | ||||||
|   int ndim = shape.size(); |   int ndim = shape.size(); | ||||||
|  |  | ||||||
|   // Set the arguments for the kernel |  | ||||||
|   compute_encoder->setComputePipelineState(kernel); |  | ||||||
|   set_array_buffer(compute_encoder, in, 0); |  | ||||||
|   set_array_buffer(compute_encoder, out, 1); |  | ||||||
|   compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); |  | ||||||
|   compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); |  | ||||||
|   compute_encoder->setBytes(&out_size, sizeof(size_t), 4); |  | ||||||
|   compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); |  | ||||||
|   compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6); |  | ||||||
|   compute_encoder->setBytes(&ndim, sizeof(int), 7); |  | ||||||
|  |  | ||||||
|   // Select block dimensions |   // Select block dimensions | ||||||
|  |  | ||||||
|   // Each thread reads 16 inputs to give it more work |   // Each thread reads 16 inputs to give it more work | ||||||
|   uint n_inputs_per_thread = REDUCE_N_READS; |   uint n_inputs_per_thread = REDUCE_N_READS; | ||||||
|   uint n_threads_per_output = |   uint n_threads_per_output = | ||||||
| @@ -183,14 +308,22 @@ void strided_reduce_general_dispatch( | |||||||
|   // We spread outputs over the x dimension and inputs over the y dimension |   // We spread outputs over the x dimension and inputs over the y dimension | ||||||
|   // Threads with the same lid.x in a given threadgroup work on the same |   // Threads with the same lid.x in a given threadgroup work on the same | ||||||
|   // output and each thread in the y dimension accumulates for that output |   // output and each thread in the y dimension accumulates for that output | ||||||
|  |  | ||||||
|  |   // Threads with same lid.x, i.e. each column of threads work on same output | ||||||
|   uint threadgroup_dim_x = std::min(out_size, 128ul); |   uint threadgroup_dim_x = std::min(out_size, 128ul); | ||||||
|  |  | ||||||
|  |   // Number of threads along y, is dependent on number of reductions needed. | ||||||
|   uint threadgroup_dim_y = |   uint threadgroup_dim_y = | ||||||
|       kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x; |       kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x; | ||||||
|   threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y); |   threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y); | ||||||
|  |  | ||||||
|  |   // Derive number of thread groups along x, based on how many threads we need | ||||||
|  |   // along x | ||||||
|   uint n_threadgroups_x = |   uint n_threadgroups_x = | ||||||
|       (out_size + threadgroup_dim_x - 1) / threadgroup_dim_x; |       (out_size + threadgroup_dim_x - 1) / threadgroup_dim_x; | ||||||
|  |  | ||||||
|  |   // Derive number of thread groups along y based on how many threads we need | ||||||
|  |   // along y | ||||||
|   uint n_threadgroups_y = |   uint n_threadgroups_y = | ||||||
|       (n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y; |       (n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y; | ||||||
|  |  | ||||||
| @@ -199,18 +332,122 @@ void strided_reduce_general_dispatch( | |||||||
|       MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions); |       MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions); | ||||||
|   MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1); |   MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1); | ||||||
|  |  | ||||||
|   // We set shared memory to be exploited here for reductions within a |   if (is_out_64b_int == false) { | ||||||
|   // threadgroup - each thread must be able to update its accumulated output |     // Set the arguments for the kernel | ||||||
|   // Note: Each threadgroup should have 32kB of data in threadgroup memory |     set_array_buffer(compute_encoder, in, 0); | ||||||
|   //       and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design |     set_array_buffer(compute_encoder, out, 1); | ||||||
|   //       This should be fine for floats, but we might need to revisit |     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); | ||||||
|   //       if we ever come to doubles. In that case, we should also cut |     compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); | ||||||
|   //       down the number of threads we launch in a threadgroup |     compute_encoder->setBytes(&out_size, sizeof(size_t), 4); | ||||||
|   compute_encoder->setThreadgroupMemoryLength( |     compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); | ||||||
|       safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16), |     compute_encoder->setBytes( | ||||||
|       0); |         strides.data(), strides.size() * sizeof(size_t), 6); | ||||||
|  |     compute_encoder->setBytes(&ndim, sizeof(int), 7); | ||||||
|  |  | ||||||
|   compute_encoder->dispatchThreadgroups(grid_dims, group_dims); |     // We set shared memory to be exploited here for reductions within a | ||||||
|  |     // threadgroup - each thread must be able to update its accumulated output | ||||||
|  |     // Note: Each threadgroup should have 32kB of data in threadgroup memory | ||||||
|  |     //       and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design | ||||||
|  |     //       This should be fine for floats, but we might need to revisit | ||||||
|  |     //       if we ever come to doubles. In that case, we should also cut | ||||||
|  |     //       down the number of threads we launch in a threadgroup | ||||||
|  |     compute_encoder->setThreadgroupMemoryLength( | ||||||
|  |         safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16), | ||||||
|  |         0); | ||||||
|  |     compute_encoder->dispatchThreadgroups(grid_dims, group_dims); | ||||||
|  |  | ||||||
|  |   } else { | ||||||
|  |     // Allocate intermediate array to store reduction results from all thread | ||||||
|  |     // groups | ||||||
|  |     array intermediate = array( | ||||||
|  |         {static_cast<int>(out.size()), | ||||||
|  |          static_cast<int>(n_threadgroups_y * non_col_reductions)}, | ||||||
|  |         out_dtype, | ||||||
|  |         nullptr, | ||||||
|  |         {}); | ||||||
|  |     intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); | ||||||
|  |     std::vector<array> intermediates = {intermediate}; | ||||||
|  |  | ||||||
|  |     // Set the arguments for the kernel | ||||||
|  |     set_array_buffer(compute_encoder, in, 0); | ||||||
|  |     set_array_buffer(compute_encoder, intermediate, 1); | ||||||
|  |     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); | ||||||
|  |     compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); | ||||||
|  |     compute_encoder->setBytes(&out_size, sizeof(size_t), 4); | ||||||
|  |     compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); | ||||||
|  |     compute_encoder->setBytes( | ||||||
|  |         strides.data(), strides.size() * sizeof(size_t), 6); | ||||||
|  |     compute_encoder->setBytes(&ndim, sizeof(int), 7); | ||||||
|  |  | ||||||
|  |     // We set shared memory to be exploited here for reductions within a | ||||||
|  |     // threadgroup - each thread must be able to update its accumulated output | ||||||
|  |     // Note: Each threadgroup should have 32kB of data in threadgroup memory | ||||||
|  |     //       and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design | ||||||
|  |     //       This should be fine for floats, but we might need to revisit | ||||||
|  |     //       if we ever come to doubles. In that case, we should also cut | ||||||
|  |     //       down the number of threads we launch in a threadgroup | ||||||
|  |     compute_encoder->setThreadgroupMemoryLength( | ||||||
|  |         safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16), | ||||||
|  |         0); | ||||||
|  |     compute_encoder->dispatchThreadgroups(grid_dims, group_dims); | ||||||
|  |  | ||||||
|  |     // Perform second pass of reductions | ||||||
|  |     // Reduce results of threadgroups along y, z from first pass, that | ||||||
|  |     // collectively work on each output element. | ||||||
|  |     reduction_size = n_threadgroups_y * non_col_reductions; | ||||||
|  |     out_size = 1; | ||||||
|  |  | ||||||
|  |     // Shape of axes that aren't participating in reduction remains unchanged. | ||||||
|  |     std::vector<int> new_shape = rem_shape; | ||||||
|  |  | ||||||
|  |     // Update their strides since they'll be different after a partial reduction | ||||||
|  |     // post first compute dispatch. | ||||||
|  |     std::vector<size_t> new_strides = rem_strides; | ||||||
|  |     new_strides.back() = reduction_size; | ||||||
|  |     for (int i = new_shape.size() - 2; i >= 0; i--) { | ||||||
|  |       new_strides[i] = new_shape[i + 1] * new_strides[i + 1]; | ||||||
|  |     } | ||||||
|  |     ndim = new_shape.size(); | ||||||
|  |  | ||||||
|  |     auto row_reduce_kernel = d.get_kernel( | ||||||
|  |         "row_reduce_general_no_atomics_" + op_name + | ||||||
|  |         type_to_name(intermediate)); | ||||||
|  |     compute_encoder->setComputePipelineState(row_reduce_kernel); | ||||||
|  |     set_array_buffer(compute_encoder, intermediate, 0); | ||||||
|  |     set_array_buffer(compute_encoder, out, 1); | ||||||
|  |     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); | ||||||
|  |     compute_encoder->setBytes(&out_size, sizeof(size_t), 3); | ||||||
|  |     compute_encoder->setBytes( | ||||||
|  |         new_shape.data(), new_shape.size() * sizeof(int), 4); | ||||||
|  |     compute_encoder->setBytes( | ||||||
|  |         new_strides.data(), new_strides.size() * sizeof(size_t), 5); | ||||||
|  |     compute_encoder->setBytes(&ndim, sizeof(int), 6); | ||||||
|  |  | ||||||
|  |     // Each thread group is responsible for 1 output | ||||||
|  |     size_t n_reads = REDUCE_N_READS; | ||||||
|  |     size_t thread_group_size = | ||||||
|  |         row_reduce_kernel->maxTotalThreadsPerThreadgroup(); | ||||||
|  |     thread_group_size = | ||||||
|  |         std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size); | ||||||
|  |  | ||||||
|  |     // Align thread group size with simd_size | ||||||
|  |     uint simd_size = row_reduce_kernel->threadExecutionWidth(); | ||||||
|  |     thread_group_size = | ||||||
|  |         (thread_group_size + simd_size - 1) / simd_size * simd_size; | ||||||
|  |     assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup()); | ||||||
|  |  | ||||||
|  |     // Launch enough thread groups for each output | ||||||
|  |     uint n_threads = thread_group_size; | ||||||
|  |     grid_dims = MTL::Size(n_threads, out.size(), 1); | ||||||
|  |     group_dims = MTL::Size(thread_group_size, 1, 1); | ||||||
|  |  | ||||||
|  |     compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||||
|  |  | ||||||
|  |     d.get_command_buffer(s.index)->addCompletedHandler( | ||||||
|  |         [intermediates](MTL::CommandBuffer*) mutable { | ||||||
|  |           intermediates.clear(); | ||||||
|  |         }); | ||||||
|  |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| } // namespace | } // namespace | ||||||
| @@ -223,14 +460,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|   assert(inputs.size() == 1); |   assert(inputs.size() == 1); | ||||||
|   array in = inputs[0]; |   array in = inputs[0]; | ||||||
|  |  | ||||||
|   // TODO: Allow specific row and column reductions with types disabled |  | ||||||
|   // due to atomics ? |  | ||||||
|   if (size_of(in.dtype()) == 8) { |  | ||||||
|     std::ostringstream msg; |  | ||||||
|     msg << "[Reduce::eval_gpu] Does not support " << in.dtype(); |  | ||||||
|     throw std::runtime_error(msg.str()); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   // Make sure no identity reductions trickle down here |   // Make sure no identity reductions trickle down here | ||||||
|   assert(!axes_.empty()); |   assert(!axes_.empty()); | ||||||
|  |  | ||||||
| @@ -297,7 +526,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|     // Reducing over everything and the data is all there no broadcasting or |     // Reducing over everything and the data is all there no broadcasting or | ||||||
|     // slicing etc. |     // slicing etc. | ||||||
|     if (plan.type == ContiguousAllReduce) { |     if (plan.type == ContiguousAllReduce) { | ||||||
|       all_reduce_dispatch(in, out, op_name, compute_encoder, d); |       all_reduce_dispatch(in, out, op_name, compute_encoder, d, s); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // At least the last dimension is row contiguous and we are reducing over |     // At least the last dimension is row contiguous and we are reducing over | ||||||
| @@ -305,7 +534,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|     else if ( |     else if ( | ||||||
|         plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) { |         plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) { | ||||||
|       row_reduce_general_dispatch( |       row_reduce_general_dispatch( | ||||||
|           in, out, op_name, plan, axes_, compute_encoder, d); |           in, out, op_name, plan, axes_, compute_encoder, d, s); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // At least the last two dimensions are contiguous and we are doing a |     // At least the last two dimensions are contiguous and we are doing a | ||||||
| @@ -314,7 +543,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|         plan.type == ContiguousStridedReduce || |         plan.type == ContiguousStridedReduce || | ||||||
|         plan.type == GeneralStridedReduce) { |         plan.type == GeneralStridedReduce) { | ||||||
|       strided_reduce_general_dispatch( |       strided_reduce_general_dispatch( | ||||||
|           in, out, op_name, plan, axes_, compute_encoder, d); |           in, out, op_name, plan, axes_, compute_encoder, d, s); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (!copies.empty()) { |     if (!copies.empty()) { | ||||||
|   | |||||||
| @@ -55,6 +55,8 @@ class TestReduce(mlx_tests.MLXTestCase): | |||||||
|             "uint8", |             "uint8", | ||||||
|             "uint16", |             "uint16", | ||||||
|             "uint32", |             "uint32", | ||||||
|  |             "int64", | ||||||
|  |             "uint64", | ||||||
|         ] |         ] | ||||||
|         float_dtypes = ["float32"] |         float_dtypes = ["float32"] | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Vijay Krish
					Vijay Krish