// row_reduce_general_small: one GPU thread computes one output element
// serially (the kernel receives only `lid [[thread_position_in_grid]]`).
//
// For output `out_idx` it visits `non_row_reductions` input rows. Row `r`
// starts at the location elem_to_loc(out_idx + r * out_size, ...) resolves
// to, and each row contributes `reduction_size` adjacent elements
// (`in_row[i]`). Each element is cast to the accumulator type U and folded
// into `total_val`. The accumulator starts at Op::init and the result is
// stored to out[out_idx].
//
// NOTE(review): this listing omits several source lines, e.g. the kernel
// name line, the declarations of `op` and `out_idx` (presumably derived
// from `lid` — confirm), the early-exit body and the closing braces. The
// comments below describe only what is visible.
8template <
typename T,
typename U,
typename Op>
10 const device T* in [[buffer(0)]],
11 device U* out [[buffer(1)]],
12 const constant
size_t& reduction_size [[buffer(2)]],
13 const constant
size_t& out_size [[buffer(3)]],
14 const constant
size_t& non_row_reductions [[buffer(4)]],
15 const constant
int* shape [[buffer(5)]],
16 const constant
size_t* strides [[buffer(6)]],
17 const constant
int& ndim [[buffer(7)]],
18 uint lid [[thread_position_in_grid]]) {
// Threads beyond the last output have no work (the guard body is not shown
// in this listing).
23 if (out_idx >= out_size) {
27 U total_val = Op::init;
// Visit every row that reduces into this output.
// NOTE(review): the loop counters are `short`, so short(non_row_reductions)
// and short(reduction_size) wrap for values above 32767. Presumably the host
// only selects this kernel for small reductions — confirm.
29 for (
short r = 0; r < short(non_row_reductions); r++) {
30 uint in_idx =
elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
31 const device T* in_row = in + in_idx;
// Each row is read as a contiguous run: in_row[0 .. reduction_size).
33 for (
short i = 0; i < short(reduction_size); i++) {
34 total_val =
op(
static_cast<U
>(in_row[i]), total_val);
// Plain (non-atomic) store of the finished reduction for this output.
38 out[out_idx] = total_val;
// row_reduce_general_med: one simdgroup computes one output element.
//
// out_idx = simd_per_group * tid + simd_group_id, so each simdgroup of
// threadgroup `tid` owns a distinct output. Every lane accumulates a partial
// result in `total_val`, starting from Op::init. op.simd_reduce then combines
// the lanes, and lane 0 stores the result. How the input is split across
// lanes depends on non_row_reductions (see the three case comments below).
//
// NOTE(review): this listing omits several source lines, e.g. the kernel
// name line, the declaration of `op`, the early-exit body, the initializer
// of `reductions_per_thread`, and the closing braces / `else {`. The comments
// describe only what is visible.
42template <
typename T,
typename U,
typename Op>
44 const device T* in [[buffer(0)]],
45 device U* out [[buffer(1)]],
46 const constant
size_t& reduction_size [[buffer(2)]],
47 const constant
size_t& out_size [[buffer(3)]],
48 const constant
size_t& non_row_reductions [[buffer(4)]],
49 const constant
int* shape [[buffer(5)]],
50 const constant
size_t* strides [[buffer(6)]],
51 const constant
int& ndim [[buffer(7)]],
52 uint tid [[threadgroup_position_in_grid]],
53 uint simd_lane_id [[thread_index_in_simdgroup]],
54 uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
55 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Global output index: one output per simdgroup.
58 uint out_idx = simd_per_group * tid + simd_group_id;
// Simdgroups beyond the last output have no work (guard body not shown).
60 if (out_idx >= out_size) {
64 U total_val = Op::init;
// Case 1: a single row. Lanes stride through it 32 elements apart.
66 if (
short(non_row_reductions) == 1) {
67 uint in_idx =
elem_to_loc(out_idx, shape, strides, ndim);
68 const device T* in_row = in + in_idx;
70 for (
short i = simd_lane_id; i < short(reduction_size); i += 32) {
71 total_val =
op(
static_cast<U
>(in_row[i]), total_val);
// Case 2: at least 32 rows. Each lane reduces whole rows
// r = lane, lane + 32, ... sequentially.
// NOTE(review): cases 1 and 2 hard-code 32 while case 3 uses simd_size.
// This assumes simd_size == 32 — confirm, or use simd_size throughout.
75 else if (
short(non_row_reductions) >= 32) {
76 for (
short r = simd_lane_id; r < short(non_row_reductions); r += 32) {
77 uint in_idx =
elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
78 const device T* in_row = in + in_idx;
80 for (
short i = 0; i < short(reduction_size); i++) {
81 total_val =
op(
static_cast<U
>(in_row[i]), total_val);
// Case 3: fewer than 32 rows, other than exactly one. The
// non_row_reductions x reduction_size elements are tiled over the
// simdgroup. Lane L handles rows r = L / reductions_per_thread
// (step simd_size / reductions_per_thread), and within each row the
// elements i = L % reductions_per_thread (step reductions_per_thread).
// NOTE(review): n_reductions is the product of two shorts, narrowed back
// to short, so it can wrap for large inputs.
88 const short n_reductions =
89 short(reduction_size) * short(non_row_reductions);
// NOTE(review): the initializer of reductions_per_thread is not shown here.
// If it ever exceeds simd_size, r_jump below becomes 0. Also, lanes with
// r_st >= r_jump (possible when simd_size is not a multiple of
// reductions_per_thread) would revisit rows that other lanes already cover.
// Confirm that a guard exists in the lines missing from this listing.
90 const short reductions_per_thread =
93 const short r_st = simd_lane_id / reductions_per_thread;
94 const short r_ed = short(non_row_reductions);
95 const short r_jump =
simd_size / reductions_per_thread;
97 const short i_st = simd_lane_id % reductions_per_thread;
98 const short i_ed = short(reduction_size);
99 const short i_jump = reductions_per_thread;
102 for (
short r = r_st; r < r_ed; r += r_jump) {
103 uint in_idx =
elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
104 const device T* in_row = in + in_idx;
106 for (
short i = i_st; i < i_ed; i += i_jump) {
107 total_val =
op(
static_cast<U
>(in_row[i]), total_val);
// Combine the per-lane partials across the simdgroup.
113 total_val =
op.simd_reduce(total_val);
// Only lane 0 writes. Each output belongs to one simdgroup, so a plain
// store suffices.
115 if (simd_lane_id == 0) {
116 out[out_idx] = total_val;
// per_thread_row_reduce: one thread's partial reduction of one row. The
// row_reduce_general* kernels call it to produce per-thread partials.
//
// Threadgroup position (tid.x, tid.y) selects logical row
// idx = tid.y * out_size + tid.x, and elem_to_loc locates that row's start.
// Thread lid_x (of lsize_x threads) starts at element lid_x * N_READS, reads
// N_READS consecutive elements per step, and advances by lsize_x * N_READS.
// Each element is cast to U and folded into an accumulator that starts at
// Op::init.
//
// NOTE(review): this listing omits several source lines:
//   - the function-name line and the `in`, lsize_x, lid_x and tid parameter
//     lines;
//   - the declarations of `op`, `r` and `vals`;
//   - the body of the first load loop;
//   - the return statement.
// The declaration list at the end of this listing shows it returns U —
// presumably total_val.
124template <
typename T,
typename U,
typename Op,
int N_READS = REDUCE_N_READS>
127 const constant
size_t& reduction_size,
128 const constant
size_t& out_size,
129 const constant
int* shape,
130 const constant
size_t* strides,
131 const constant
int& ndim,
// NOTE(review): idx and extra_offset are `int`, so inputs whose element
// count or offset exceeds INT_MAX would overflow. Confirm that callers
// bound the size.
139 int idx = tid.y * out_size + tid.x;
140 int extra_offset =
elem_to_loc(idx, shape, strides, ndim);
141 in += extra_offset + lid_x * N_READS;
144 U total_val = Op::init;
// Every stride except the last is entirely in bounds: for iteration r the
// highest index read is (r + 1) * N_READS * lsize_x - 1 < reduction_size.
// These strides are therefore read without masking.
148 for (; r < (int)
ceildiv(reduction_size, N_READS * lsize_x) - 1; r++) {
150 for (
int i = 0; i < N_READS; i++) {
153 for (
int i = 0; i < N_READS; i++) {
154 total_val =
op(
static_cast<U
>(vals[i]), total_val);
157 in += lsize_x * N_READS;
// Last stride: only max_reads = reduction_size - reduction_index elements
// remain for this thread. If max_reads >= N_READS, all reads are valid.
161 size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
162 if (reduction_index < reduction_size) {
163 int max_reads = reduction_size - reduction_index;
166 for (
int i = 0; i < N_READS; i++) {
// Clamp the load index so no read goes past the end of the row. The
// duplicated trailing values are masked out below. (This `idx` shadows
// the row index declared above.)
167 int idx = min(i, max_reads - 1);
168 vals[i] =
static_cast<U
>(in[idx]);
170 for (
int i = 0; i < N_READS; i++) {
// Slots past the end contribute Op::init instead of the duplicated value.
// NOTE(review): `val` is declared T, so Op::init (used to initialize a U
// accumulator above) is converted to T and then back to U. If that
// identity value is not representable in T, the result changes. Declaring
// `val` as U would avoid the round trip.
171 T val = i < max_reads ? vals[i] : Op::init;
172 total_val =
op(
static_cast<U
>(val), total_val);
// row_reduce_general: full row reduction that merges each threadgroup's
// result into the output with an atomic update.
//
// The stages are:
//   1. Each thread computes a partial via per_thread_row_reduce.
//   2. op.simd_reduce folds the partials within each simdgroup.
//   3. Lane 0 of every simdgroup stores its value in
//      local_vals[simd_group_id].
//   4. After a threadgroup barrier, the simdgroup partials are reduced again.
//      Lanes with lid.x >= simd_per_group contribute op.init.
//   5. The result is merged into output tid.x with op.atomic_update.
// per_thread_row_reduce maps threadgroup (tid.x, tid.y) to row
// tid.y * out_size + tid.x. Threadgroups with different tid.y therefore
// combine into the same output, which is why the final update is atomic.
//
// NOTE(review): this listing omits several source lines:
//   - the kernel name line;
//   - the `out` parameter (the declaration list at the end of this listing
//     shows `device mlx_atomic<U>* out`);
//   - the declarations of `op` and `local_vals`;
//   - the per_thread_row_reduce arguments;
//   - the guards selecting which simdgroup/lane runs the second stage and
//     the atomic update.
179template <
typename T,
typename U,
typename Op,
int N_READS = REDUCE_N_READS>
181 const device T* in [[buffer(0)]],
183 const constant
size_t& reduction_size [[buffer(2)]],
184 const constant
size_t& out_size [[buffer(3)]],
185 const constant
size_t& non_row_reductions [[buffer(4)]],
186 const constant
int* shape [[buffer(5)]],
187 const constant
size_t* strides [[buffer(6)]],
188 const constant
int& ndim [[buffer(7)]],
189 uint3 lid [[thread_position_in_threadgroup]],
190 uint3 lsize [[threads_per_threadgroup]],
191 uint3 tid [[threadgroup_position_in_grid]],
192 uint simd_lane_id [[thread_index_in_simdgroup]],
193 uint simd_per_group [[simdgroups_per_threadgroup]],
194 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// non_row_reductions is unused here. Presumably it is kept so that every
// row-reduce kernel shares the same buffer(0..7) layout.
195 (void)non_row_reductions;
// Stage 1: this thread's partial over its slice of the row.
200 U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
// Stage 2: combine partials within the simdgroup.
211 total_val =
op.simd_reduce(total_val);
// One value per simdgroup is staged in threadgroup memory.
214 if (simd_lane_id == 0) {
215 local_vals[simd_group_id] = total_val;
217 threadgroup_barrier(mem_flags::mem_threadgroup);
// Stage 3: reduce the simdgroup partials. Indexing local_vals by lid.x
// presumably runs in a single simdgroup (guard not shown here). Lanes past
// simd_per_group pad with the identity op.init.
222 total_val = lid.x < simd_per_group ? local_vals[lid.x] :
op.init;
223 total_val =
op.simd_reduce(total_val);
// Several threadgroups (different tid.y) target out[tid.x], hence atomic.
227 op.atomic_update(out, total_val, tid.x);
// row_reduce_general_no_atomics: the same two-level reduction as
// row_reduce_general, but without an atomic update. Each threadgroup writes
// its partial result to its own slot:
//   out[ceildiv(gsize.y, lsize.y) * tid.x + tid.y]
// For output tid.x, the ceildiv(gsize.y, lsize.y) partials (one per
// threadgroup along y) are therefore stored contiguously. Presumably a later
// pass reduces these partials into the final output — confirm with the host
// code.
//
// Within a simdgroup, combination uses explicit loops over
// i = simd_size / 2, simd_size / 4, ..., 1 (log2(simd_size) steps) rather
// than op.simd_reduce. The loop bodies are not shown here; presumably they
// are shuffle-down steps.
//
// NOTE(review): this listing omits several source lines:
//   - the kernel name line;
//   - the declarations of `op` and `local_vals`;
//   - the per_thread_row_reduce arguments;
//   - the loop bodies;
//   - the guards around the second stage and the final store.
231template <
typename T,
typename U,
typename Op,
int N_READS = REDUCE_N_READS>
233 const device T* in [[buffer(0)]],
234 device U* out [[buffer(1)]],
235 const constant
size_t& reduction_size [[buffer(2)]],
236 const constant
size_t& out_size [[buffer(3)]],
237 const constant
size_t& non_row_reductions [[buffer(4)]],
238 const constant
int* shape [[buffer(5)]],
239 const constant
size_t* strides [[buffer(6)]],
240 const constant
int& ndim [[buffer(7)]],
241 uint3 lid [[thread_position_in_threadgroup]],
242 uint3 lsize [[threads_per_threadgroup]],
243 uint3 gsize [[threads_per_grid]],
244 uint3 tid [[threadgroup_position_in_grid]],
245 uint simd_lane_id [[thread_index_in_simdgroup]],
246 uint simd_per_group [[simdgroups_per_threadgroup]],
247 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// non_row_reductions is unused here. Presumably it is kept so that every
// row-reduce kernel shares the same buffer(0..7) layout.
248 (void)non_row_reductions;
// Stage 1: this thread's partial over its slice of the row.
253 U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
// Stage 2: halving tree reduction across the simdgroup (body not shown).
265 for (uint16_t i =
simd_size / 2; i > 0; i /= 2) {
// One value per simdgroup is staged in threadgroup memory.
270 if (simd_lane_id == 0) {
271 local_vals[simd_group_id] = total_val;
273 threadgroup_barrier(mem_flags::mem_threadgroup);
// Stage 3: reduce the simdgroup partials. Lanes past simd_per_group pad
// with the identity op.init.
278 total_val = lid.x < simd_per_group ? local_vals[lid.x] :
op.init;
279 for (uint16_t i =
simd_size / 2; i > 0; i /= 2) {
// Non-atomic store of this threadgroup's partial into its private slot.
// ceildiv(gsize.y, lsize.y) is the number of threadgroups along y.
285 out[(
ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
Op op
Definition binary.h:141
void row_reduce_general_no_atomics(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint3 lid, uint3 lsize, uint3 gsize, uint3 tid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Definition reduce_row.h:232
METAL_FUNC U per_thread_row_reduce(const device T *in, const constant size_t &reduction_size, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint lsize_x, uint lid_x, uint2 tid)
Definition reduce_row.h:125
void row_reduce_general(const device T *in, device mlx_atomic< U > *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint3 lid, uint3 lsize, uint3 tid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Definition reduce_row.h:180
void row_reduce_general_med(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint tid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Definition reduce_row.h:43
void row_reduce_general_small(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &out_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint lid)
Definition reduce_row.h:9