MLX
Loading...
Searching...
No Matches
reduce_col.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
4// Small column reduce kernel
6
7template <typename T, typename U, typename Op>
8[[kernel]] void col_reduce_small(
9 const device T* in [[buffer(0)]],
10 device U* out [[buffer(1)]],
11 const constant size_t& reduction_size [[buffer(2)]],
12 const constant size_t& reduction_stride [[buffer(3)]],
13 const constant size_t& out_size [[buffer(4)]],
14 const constant int* shape [[buffer(5)]],
15 const constant size_t* strides [[buffer(6)]],
16 const constant int& ndim [[buffer(7)]],
17 const constant size_t& non_col_reductions [[buffer(8)]],
18 const constant int* non_col_shapes [[buffer(9)]],
19 const constant size_t* non_col_strides [[buffer(10)]],
20 const constant int& non_col_ndim [[buffer(11)]],
21 uint tid [[thread_position_in_grid]]) {
22 // Appease the compiler
23 (void)out_size;
24
25 Op op;
26 U total_val = Op::init;
27
28 auto out_idx = tid;
29
30 in += elem_to_loc(
31 out_idx,
32 shape + non_col_ndim,
33 strides + non_col_ndim,
34 ndim - non_col_ndim);
35
36 for (uint i = 0; i < non_col_reductions; i++) {
37 size_t in_idx =
38 elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
39
40 for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
41 U val = static_cast<U>(in[in_idx]);
42 total_val = op(total_val, val);
43 }
44 }
45
46 out[out_idx] = total_val;
47}
48
50// Column reduce helper
52
53template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
55 const device T* in,
56 threadgroup U* local_data,
57 uint in_idx,
58 uint reduction_size,
59 uint reduction_stride,
60 uint2 tid,
61 uint2 lid,
62 uint2 lsize) {
63 Op op;
64 U total_val = Op::init;
65
66 uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
67 for (uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
68 uint offset = base_offset + r;
69 total_val =
70 op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
71 }
72 local_data[lsize.y * lid.x + lid.y] = total_val;
73 threadgroup_barrier(mem_flags::mem_threadgroup);
74
75 U val = Op::init;
76 if (lid.y == 0) {
77 // Perform reduction across columns in thread group
78 for (uint i = 0; i < lsize.y; i++) {
79 val = op(val, local_data[lsize.y * lid.x + i]);
80 }
81 }
82
83 return val;
84}
85
87// Column reduce kernel
89
90template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
91[[kernel]] void col_reduce_general(
92 const device T* in [[buffer(0)]],
93 device mlx_atomic<U>* out [[buffer(1)]],
94 const constant size_t& reduction_size [[buffer(2)]],
95 const constant size_t& reduction_stride [[buffer(3)]],
96 const constant size_t& out_size [[buffer(4)]],
97 const constant int* shape [[buffer(5)]],
98 const constant size_t* strides [[buffer(6)]],
99 const constant int& ndim [[buffer(7)]],
100 threadgroup U* local_data [[threadgroup(0)]],
101 uint3 tid [[threadgroup_position_in_grid]],
102 uint3 lid [[thread_position_in_threadgroup]],
103 uint3 lsize [[threads_per_threadgroup]]) {
104 auto out_idx = tid.x * lsize.x + lid.x;
105 auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
106
107 Op op;
108 if (out_idx < out_size) {
109 U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
110 in,
111 local_data,
112 in_idx,
113 reduction_size,
114 reduction_stride,
115 tid.xy,
116 lid.xy,
117 lsize.xy);
118
119 // Write out reduction results generated by threadgroups working on specific
120 // output element, contiguously.
121 if (lid.y == 0) {
122 op.atomic_update(out, val, out_idx);
123 }
124 }
125}
126
127template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
129 const device T* in [[buffer(0)]],
130 device U* out [[buffer(1)]],
131 const constant size_t& reduction_size [[buffer(2)]],
132 const constant size_t& reduction_stride [[buffer(3)]],
133 const constant size_t& out_size [[buffer(4)]],
134 const constant int* shape [[buffer(5)]],
135 const constant size_t* strides [[buffer(6)]],
136 const constant int& ndim [[buffer(7)]],
137 threadgroup U* local_data [[threadgroup(0)]],
138 uint3 tid [[threadgroup_position_in_grid]],
139 uint3 lid [[thread_position_in_threadgroup]],
140 uint3 gid [[thread_position_in_grid]],
141 uint3 lsize [[threads_per_threadgroup]],
142 uint3 gsize [[threads_per_grid]]) {
143 auto out_idx = tid.x * lsize.x + lid.x;
144 auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
145
146 if (out_idx < out_size) {
147 U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
148 in,
149 local_data,
150 in_idx,
151 reduction_size,
152 reduction_stride,
153 tid.xy,
154 lid.xy,
155 lsize.xy);
156
157 // Write out reduction results generated by threadgroups working on specific
158 // output element, contiguously.
159 if (lid.y == 0) {
160 uint tgsize_y = ceildiv(gsize.y, lsize.y);
161 uint tgsize_z = ceildiv(gsize.z, lsize.z);
162 out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
163 }
164 }
165}
METAL_FUNC stride_t elem_to_loc(uint elem, device const int *shape, device const stride_t *strides, int ndim)
Definition utils.h:77
size_t ceildiv(size_t N, size_t M)
Compute ceil((float)N/(float)M)
Definition utils.h:296
Op op
Definition binary.h:141
void col_reduce_general(const device T *in, device mlx_atomic< U > *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, threadgroup U *local_data, uint3 tid, uint3 lid, uint3 lsize)
Definition reduce_col.h:91
METAL_FUNC U _contiguous_strided_reduce(const device T *in, threadgroup U *local_data, uint in_idx, uint reduction_size, uint reduction_stride, uint2 tid, uint2 lid, uint2 lsize)
Definition reduce_col.h:54
void col_reduce_small(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant size_t &non_col_reductions, const constant int *non_col_shapes, const constant size_t *non_col_strides, const constant int &non_col_ndim, uint tid)
Definition reduce_col.h:8
void col_reduce_general_no_atomics(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, threadgroup U *local_data, uint3 tid, uint3 lid, uint3 gid, uint3 lsize, uint3 gsize)
Definition reduce_col.h:128
Definition atomic.h:25