// col_reduce_small: each invocation (tid = thread_position_in_grid)
// accumulates a single value of type U, starting from Op::init, and stores it
// to out[out_idx].
//
// Reduction structure visible in this extract:
//   - outer loop over i in [0, non_col_reductions); each i is mapped to an
//     offset by elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
//   - inner loop over reduction_size elements of `in`, advancing in_idx by
//     reduction_stride each step; every element is converted T -> U and then
//     folded into the accumulator with op(total_val, val).
//
// Template parameters:
//   T  - element type of the input buffer `in`.
//   U  - accumulator and output element type (`out`, `total_val`).
//   Op - reduction operator type; provides Op::init. The `op` object called
//        below is declared on a line that is not shown here.
//
// NOTE(review): this listing is an extract with gaps. The fused source line
// numbers skip 8, 22-25, 27-32, 34-35, 37, 39 and 43-45, so this view does not
// show the line naming the function, how out_idx and the starting in_idx are
// computed, or the closing braces. Presumably out_idx is derived from tid —
// confirm against the full reduce_col.h.
7template <
typename T,
typename U,
typename Op>
// Buffer bindings 0-11: input, output, reduction_size/reduction_stride,
// out_size, shape/strides/ndim, then the non-column reduction count and its
// own shape/strides/ndim.
9 const device T* in [[buffer(0)]],
10 device U* out [[buffer(1)]],
11 const constant
size_t& reduction_size [[buffer(2)]],
12 const constant
size_t& reduction_stride [[buffer(3)]],
13 const constant
size_t& out_size [[buffer(4)]],
14 const constant
int* shape [[buffer(5)]],
15 const constant
size_t* strides [[buffer(6)]],
16 const constant
int& ndim [[buffer(7)]],
17 const constant
size_t& non_col_reductions [[buffer(8)]],
18 const constant
int* non_col_shapes [[buffer(9)]],
19 const constant
size_t* non_col_strides [[buffer(10)]],
20 const constant
int& non_col_ndim [[buffer(11)]],
21 uint tid [[thread_position_in_grid]]) {
// Accumulator starts at Op::init.
26 U total_val = Op::init;
// Tail of a statement whose beginning (source lines 27-32) is missing from
// this extract; it passes `strides` offset by non_col_ndim entries.
33 strides + non_col_ndim,
// Outer loop: one iteration per non-column reduction position.
// NOTE(review): `i` (and `j` below) are uint while the bounds are size_t. If
// size_t is wider than uint here, a count above UINT_MAX could never be
// reached by the counter. Confirm the host keeps these sizes in uint range.
36 for (uint i = 0; i < non_col_reductions; i++) {
// Map flat index i to an offset through the non-column shape/strides. The
// start of this statement (source line 37) is not shown.
38 elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
// Inner loop: reduction_size elements, reading in[in_idx] and advancing in_idx
// by reduction_stride after each step.
40 for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
// Convert the input element to the accumulator type before combining.
41 U val =
static_cast<U
>(in[in_idx]);
42 total_val =
op(total_val, val);
// Store this invocation's result.
46 out[out_idx] = total_val;
// _contiguous_strided_reduce: helper that returns a U. Per the Doxygen tooltip
// its full parameter list is (in, local_data, in_idx, reduction_size,
// reduction_stride, uint2 tid, uint2 lid, uint2 lsize).
//
// Visible phases:
//   1. Each thread starts from Op::init and folds up to N_READS elements at
//      reduction positions base_offset .. base_offset + N_READS - 1, where
//      base_offset = (tid.y * lsize.y + lid.y) * N_READS. Positions at or past
//      reduction_size are skipped. The element at position `offset` is
//      in[in_idx + offset * reduction_stride].
//   2. The per-thread partial is written to local_data[lsize.y * lid.x + lid.y],
//      so local_data must hold at least lsize.x * lsize.y values of U.
//      A threadgroup-memory barrier follows.
//   3. The lsize.y partials in row lid.x of local_data are folded into `val`.
//
// Template parameters:
//   T, U, Op - input type, accumulator type, reduction operator (as in the
//              kernels that call this helper).
//   N_READS  - consecutive reduction positions handled per thread; defaults to
//              REDUCE_N_READS.
//
// NOTE(review): this extract is missing source lines 54-55, 57-58, 60-63, 65,
// 69, 71, 74-77 and 80 onward. The remaining parameters, the declaration of
// `val`, any guard around phase 3, and the return statement are not visible.
53template <
typename T,
typename U,
typename Op,
int N_READS = REDUCE_N_READS>
56 threadgroup U* local_data,
59 uint reduction_stride,
64 U total_val = Op::init;
// First reduction position handled by this thread (N_READS per thread).
66 uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
67 for (uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
68 uint offset = base_offset + r;
// The assignment target of this op(...) call (source line 69) is not shown.
// NOTE(review): the cast is applied to total_val, which is already U, while
// the T element is passed to op uncast. col_reduce_small casts the element
// instead. Confirm that Op accepts a mixed (U, T) pair, or cast the element.
70 op(
static_cast<U
>(total_val), in[in_idx + offset * reduction_stride]);
// Publish the per-thread partial to threadgroup memory, then synchronize so
// every partial is visible before the row is combined.
72 local_data[lsize.y * lid.x + lid.y] = total_val;
73 threadgroup_barrier(mem_flags::mem_threadgroup);
// Fold the lsize.y partials stored for column lid.x.
78 for (uint i = 0; i < lsize.y; i++) {
79 val =
op(val, local_data[lsize.y * lid.x + i]);
// col_reduce_general: column reduction that combines partial results into the
// output atomically. Per the Doxygen tooltip, `out` (buffer 1, on source
// line 93, which is not shown) is a `device mlx_atomic<U>*`.
//
// Index mapping visible here:
//   - out_idx = tid.x * lsize.x + lid.x: the grid's x dimension selects the
//     output element.
//   - in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim):
//     the starting input offset. tid.z shifts the flat index by whole
//     multiples of out_size (presumably one batch of outputs per z
//     threadgroup — confirm).
// Because out_idx depends only on tid.x and lid.x, threadgroups that differ in
// tid.y or tid.z target the same output element. Their partials are therefore
// merged with op.atomic_update.
//
// NOTE(review): source lines 110-121 are missing from this extract, so the
// remaining arguments to _contiguous_strided_reduce and any guard (for
// example on lid.y) around the atomic update are not visible.
90template <
typename T,
typename U,
typename Op,
int N_READS = REDUCE_N_READS>
92 const device T* in [[buffer(0)]],
94 const constant
size_t& reduction_size [[buffer(2)]],
95 const constant
size_t& reduction_stride [[buffer(3)]],
96 const constant
size_t& out_size [[buffer(4)]],
97 const constant
int* shape [[buffer(5)]],
98 const constant
size_t* strides [[buffer(6)]],
99 const constant
int& ndim [[buffer(7)]],
// Scratch space passed on to _contiguous_strided_reduce.
100 threadgroup U* local_data [[threadgroup(0)]],
101 uint3 tid [[threadgroup_position_in_grid]],
102 uint3 lid [[thread_position_in_threadgroup]],
103 uint3 lsize [[threads_per_threadgroup]]) {
104 auto out_idx = tid.x * lsize.x + lid.x;
// NOTE(review): in_idx is computed before the out_idx < out_size check below.
// The visible lines only use it inside that branch, so it could be moved
// there.
105 auto in_idx =
elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
// Only threads mapped to a real output element participate.
108 if (out_idx < out_size) {
109 U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
// Merge this threadgroup's partial into the shared output element.
122 op.atomic_update(out, val, out_idx);
// col_reduce_general_no_atomics: same partial reduction as
// col_reduce_general, but each threadgroup's partial is stored to its own slot
// in a plain `device U*` buffer instead of being merged atomically.
//
// Output layout: tgsize_y = ceildiv(gsize.y, lsize.y) and
// tgsize_z = ceildiv(gsize.z, lsize.z) are the threadgroup counts along y and
// z (threads_per_grid divided by threads_per_threadgroup, rounded up). Every
// gid.x owns a contiguous run of tgsize_y * tgsize_z slots, one per
// (tid.z, tid.y) threadgroup. `out` must therefore hold at least
// gsize.x * tgsize_y * tgsize_z values. Presumably a later pass reduces these
// partials to final outputs — confirm with the host-side dispatch.
//
// NOTE(review): source lines 128, 145 and 148-159 are missing from this
// extract. `val` is declared inside the `if (out_idx < out_size)` branch, and
// this view cannot confirm that the final store is in the same scope.
127template <
typename T,
typename U,
typename Op,
int N_READS = REDUCE_N_READS>
129 const device T* in [[buffer(0)]],
130 device U* out [[buffer(1)]],
131 const constant
size_t& reduction_size [[buffer(2)]],
132 const constant
size_t& reduction_stride [[buffer(3)]],
133 const constant
size_t& out_size [[buffer(4)]],
134 const constant
int* shape [[buffer(5)]],
135 const constant
size_t* strides [[buffer(6)]],
136 const constant
int& ndim [[buffer(7)]],
137 threadgroup U* local_data [[threadgroup(0)]],
138 uint3 tid [[threadgroup_position_in_grid]],
139 uint3 lid [[thread_position_in_threadgroup]],
140 uint3 gid [[thread_position_in_grid]],
141 uint3 lsize [[threads_per_threadgroup]],
142 uint3 gsize [[threads_per_grid]]) {
// Same output/input index mapping as col_reduce_general.
143 auto out_idx = tid.x * lsize.x + lid.x;
144 auto in_idx =
elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
146 if (out_idx < out_size) {
147 U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
// Number of threadgroups along y and z, rounding up partial threadgroups.
160 uint tgsize_y =
ceildiv(gsize.y, lsize.y);
161 uint tgsize_z =
ceildiv(gsize.z, lsize.z);
// NOTE(review): the slot is keyed by gid.x, while the bounds check above uses
// out_idx = tid.x * lsize.x + lid.x. These are expected to match for uniform
// threadgroups along x; confirm that holds for the dispatch used.
162 out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
Op op
Definition binary.h:141
void col_reduce_general(const device T *in, device mlx_atomic< U > *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, threadgroup U *local_data, uint3 tid, uint3 lid, uint3 lsize)
Definition reduce_col.h:91
METAL_FUNC U _contiguous_strided_reduce(const device T *in, threadgroup U *local_data, uint in_idx, uint reduction_size, uint reduction_stride, uint2 tid, uint2 lid, uint2 lsize)
Definition reduce_col.h:54
void col_reduce_small(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant size_t &non_col_reductions, const constant int *non_col_shapes, const constant size_t *non_col_strides, const constant int &non_col_ndim, uint tid)
Definition reduce_col.h:8
void col_reduce_general_no_atomics(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, threadgroup U *local_data, uint3 tid, uint3 lid, uint3 gid, uint3 lsize, uint3 gsize)
Definition reduce_col.h:128