template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_all_reduce(
    const device T* in,
    const device size_t& in_size,
    uint gid,
    uint grid_size) {
  // Each thread reduces a strided slice of `in` into a single accumulator.
  // Threads read N_READS contiguous elements per round, then jump ahead by
  // grid_size * N_READS so consecutive rounds cover the whole input.
  Op op;
  U total_val = Op::init;

  if (gid * N_READS < in_size) {
    // Position this thread at its first batch of N_READS elements.
    in += gid * N_READS;

    int r = 0;
    // All full rounds except the last: every read is in bounds, so no
    // per-element bounds checks are needed.
    for (; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
      U vals[N_READS] = {op.init};

      // Load first, then combine, so the loads can be issued back-to-back.
      for (int i = 0; i < N_READS; i++) {
        vals[i] = static_cast<U>(in[i]);
      }
      for (int i = 0; i < N_READS; i++) {
        total_val = op(vals[i], total_val);
      }

      in += grid_size * N_READS;
    }

    // Last (possibly partial) round: clamp reads to stay inside the buffer.
    size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
    if (curr_idx < in_size) {
      int max_reads = in_size - curr_idx;
      T vals[N_READS];

      // Clamp the load index so out-of-range lanes re-read the last valid
      // element (a safe read; the duplicate is masked out below).
      for (int i = 0, idx = 0; i < N_READS; i++, idx++) {
        idx = idx < max_reads ? idx : max_reads - 1;
        vals[i] = in[idx];
      }
      // Replace the clamped duplicates with the identity before combining.
      for (int i = 0; i < N_READS; i++) {
        U val = i < max_reads ? vals[i] : Op::init;
        total_val = op(static_cast<U>(val), total_val);
      }
    }
  }

  return total_val;
}
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce(
    const device T* in [[buffer(0)]],
    device mlx_atomic<U>* out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Full reduction of `in` into a single atomic output cell. Three stages:
  // per-thread serial reduce, simdgroup reduce, threadgroup reduce, then one
  // atomic update per threadgroup to combine across the grid.
  Op op;
  threadgroup U local_vals[simd_size];

  // Stage 1: each thread reduces its strided slice of the input.
  U total_val =
      per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

  // Stage 2: reduce within the simdgroup; lane 0 publishes the result.
  total_val = op.simd_reduce(total_val);
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }

  // Stage 3: reduce the per-simdgroup partials within the threadgroup.
  // Threads beyond simd_per_group contribute the identity.
  threadgroup_barrier(mem_flags::mem_threadgroup);
  total_val = lid < simd_per_group ? local_vals[lid] : op.init;
  total_val = op.simd_reduce(total_val);

  // Stage 4: one atomic update per threadgroup folds into the global result.
  if (lid == 0) {
    op.atomic_update(out, total_val);
  }
}
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint thread_group_id [[threadgroup_position_in_grid]]) {
  // Variant of all_reduce for types/ops without atomic support: each
  // threadgroup writes its partial result to out[thread_group_id]; a
  // follow-up pass (outside this kernel) reduces those partials.
  // Simdgroup reduction is done with explicit shuffles instead of
  // op.simd_reduce — presumably because simd reduction intrinsics are not
  // available for 64-bit types (TODO confirm against the original header).
  Op op;
  threadgroup U local_vals[simd_size];

  // Stage 1: each thread reduces its strided slice of the input.
  U total_val =
      per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

  // Stage 2: butterfly reduction within the simdgroup via shuffle-down.
  for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
       lane_offset /= 2) {
    total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
  }
  // Lane 0 of each simdgroup publishes its partial to threadgroup memory.
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Stage 3: reduce the per-simdgroup partials; extra lanes contribute the
  // identity so the shuffle reduction stays well-defined.
  total_val = lid < simd_per_group ? local_vals[lid] : op.init;
  for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
       lane_offset /= 2) {
    total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
  }

  // Stage 4: one partial result per threadgroup, reduced later on the host
  // side or by a second kernel launch.
  if (lid == 0) {
    out[thread_group_id] = total_val;
  }
}
Op op
Definition binary.h:141
METAL_FUNC U per_thread_all_reduce(const device T *in, const device size_t &in_size, uint gid, uint grid_size)
Definition reduce_all.h:8
void all_reduce_no_atomics(const device T *in, device U *out, const device size_t &in_size, uint gid, uint lid, uint grid_size, uint simd_per_group, uint simd_lane_id, uint simd_group_id, uint thread_group_id)
Definition reduce_all.h:95
void all_reduce(const device T *in, device mlx_atomic< U > *out, const device size_t &in_size, uint gid, uint lid, uint grid_size, uint simd_per_group, uint simd_lane_id, uint simd_group_id)
Definition reduce_all.h:61