MLX
Loading...
Searching...
No Matches
loader_channel_n.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
6
8
10// Loading helper
12
13namespace mlx {
14namespace steel {
15
// Compile-time helper describing how a group of n_channels_ channels is padded.
// NOTE(review): the struct header (listing line 17, presumably
// `struct ChannelHelper {`) is absent from this extract.
//   n_channels : the template argument, re-exported as a constant.
//   vec_size   : 4 when n_channels_ <= 4, otherwise 8.
//   excess     : vec_size - n_channels_.
16template <short n_channels_>
18 STEEL_CONST short n_channels = n_channels_;
19 STEEL_CONST short vec_size = n_channels_ <= 4 ? 4 : 8;
20 STEEL_CONST short excess = vec_size - n_channels_;
21};
22
// Explicit specializations of ChannelHelper for 1, 2, 3 and 4 channels.
// NOTE(review): the member definitions of each specialization (listing lines
// 25-27, 32-34, 39-41 and 46-48) are absent from this extract, so the values
// they assign to n_channels / vec_size / excess cannot be read from here —
// confirm against the full header.
23template <>
24struct ChannelHelper<1> {
28};
29
30template <>
31struct ChannelHelper<2> {
35};
36
37template <>
38struct ChannelHelper<3> {
42};
43
44template <>
45struct ChannelHelper<4> {
49};
50
51template <
52 typename T,
53 short BM,
54 short BN,
55 short BK,
56 short tgp_size,
57 short n_channels,
58 short tgp_padding = 0>
60 // Destination dimensions
61 STEEL_CONST short BROWS = BM;
62 STEEL_CONST short BCOLS = BK;
63
64 // Read dimensions
65 STEEL_CONST short dst_ld = BCOLS + tgp_padding;
67
68 // Thread read shape
70 STEEL_CONST short TROWS = tgp_size / TCOLS;
71
72 // Rows / strided reads within the block
74
75 // Thread location indices
76 const short thread_idx;
77 const short bi;
78 const short bj;
79
80 // threadgroup and device memory
81 threadgroup T* dst;
82
83 const constant MLXConvParams<2>* params;
85
86 short weight_hw;
87
88 const device T* src[n_rows];
89
93
94 /* Constructor */
// Records, for each of this thread's n_rows rows, which output position it
// serves and where the matching input window starts.
// NOTE(review): the first line of the signature (listing line 95) and the
// initializers on listing lines 104, 105 and 109 are absent from this extract;
// bi, bj and weight_hw are presumably initialized there — confirm against the
// full header.
//
// Parameters:
//   src_          base pointer of the input in device memory
//   dst_          base pointer of the destination tile in threadgroup memory
//   offsets       only offsets.y is read here; it is the first flat row index
//                 handled by this tile
//   params_       convolution parameters (oS, str, pad, in_strides are read here)
//   gemm_params_  stored in gemm_params; not otherwise used in this constructor
//   simd_group_id, simd_lane_id
//                 combined into thread_idx as simd_group_id * 32 + simd_lane_id
//                 (the factor 32 is presumably the simdgroup width)
96 const device T* src_,
97 threadgroup T* dst_,
98 const int2 offsets,
99 const constant MLXConvParams<2>* params_,
100 const constant ImplicitGemmConv2DParams* gemm_params_,
101 uint simd_group_id [[simdgroup_index_in_threadgroup]],
102 uint simd_lane_id [[thread_index_in_simdgroup]])
103 : thread_idx(simd_group_id * 32 + simd_lane_id),
// dst points at this thread's first slot: row bi, column bj, row pitch dst_ld.
106 dst(dst_ + bi * dst_ld + bj),
107 params(params_),
108 gemm_params(gemm_params_),
// Product of the two oS extents; used below to split a flat index into
// (n, position within one oS[0] x oS[1] plane).
110 int out_n_pixels = params->oS[0] * params->oS[1];
111
113 for (short i = 0; i < n_rows; ++i) {
// Flat index of the i-th row owned by this thread (rows are TROWS apart).
114 int offset_nhw = offsets.y + bi + i * TROWS;
// Decompose the flat index into n, then (oh, ow) within the plane.
115 int n = offset_nhw / out_n_pixels;
116 int hw = offset_nhw % out_n_pixels;
117 int oh = hw / params->oS[1];
118 int ow = hw % params->oS[1];
119
// Input coordinates before any kernel offset is added; these are negative
// whenever pad exceeds oh * str (resp. ow * str).
120 int ih = oh * params->str[0] - params->pad[0];
121 int iw = ow * params->str[1] - params->pad[1];
122
123 // Base pointer for this row; computed without a bounds check (load_unsafe() tests n/ih/iw before reading)
124 src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
125 iw * params->in_strides[2];
126
// Cache the coordinates so load_unsafe() can bounds-check without redoing
// the divisions above.
127 read_n[i] = n;
128 read_ih[i] = ih;
129 read_iw[i] = iw;
130 }
131 }
132
133 /* Load from device memory into threadgroup memory - without bound checking */
// Writes vec_size values per owned row into dst for the kernel position
// selected by weight_hw. Despite the name, n / ih / iw ARE range-checked below
// and out-of-range rows are zero-filled; channels [n_channels, vec_size) are
// always zero-filled.
// NOTE(review): several listing lines inside this body (136, 138, 155, 168,
// 173, 181) are absent from this extract — presumably unroll pragmas placed
// before the loops; confirm against the full header.
134 METAL_FUNC void load_unsafe() const {
// weight_hw is past the last position of the wS[0] x wS[1] kernel window:
// nothing to read, so clear this thread's slots and stop.
135 if (weight_hw >= params->wS[1] * params->wS[0]) {
137 for (short i = 0; i < BROWS; i += TROWS) {
139 for (short j = 0; j < vec_size; j++) {
140 dst[i * dst_ld + j] = T(0);
141 }
142 }
143 return;
144 }
145
// Split the flat kernel index into (wh, ww) using wS[1] as the row length.
146 int wh = (weight_hw / params->wS[1]);
147 int ww = (weight_hw % params->wS[1]);
148
// When params->flip is set, mirror the kernel position along both axes.
149 int flip_h = params->flip ? params->wS[0] - wh - 1 : wh;
150 int flip_w = params->flip ? params->wS[1] - ww - 1 : ww;
151
// Scale by kdil to get the offset, in input elements, from the window origin.
152 int weight_h = flip_h * params->kdil[0];
153 int weight_w = flip_w * params->kdil[1];
154
// i indexes the per-thread cached rows; is is the matching row offset in the
// destination tile (rows owned by one thread are TROWS apart).
156 for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
157 // Find bounds
158 int n = read_n[i];
159 int ih = read_ih[i] + weight_h;
160 int iw = read_iw[i] + weight_w;
161
162 // Read from input if in bounds
163 if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
164 (iw >= 0 && iw < params->iS[1])) {
// src[i] already points at the window origin for this row; add the offset of
// the current kernel position.
165 const device T* curr_src = src[i] + weight_h * params->in_strides[1] +
166 weight_w * params->in_strides[2];
167
// Copy the n_channels real values.
169 for (short j = 0; j < n_channels; ++j) {
170 dst[is * dst_ld + j] = curr_src[j];
171 }
172
// Zero the remaining slots up to vec_size.
174 for (short j = n_channels; j < vec_size; ++j) {
175 dst[is * dst_ld + j] = T(0);
176 }
177 }
178
179 // Zero pad otherwise
180 else {
182 for (short j = 0; j < vec_size; ++j) {
183 dst[is * dst_ld + j] = T(0);
184 }
185 }
186 }
187 }
188
189 /* Iteration helper */
// Advances weight_hw by TCOLS. load_unsafe() interprets weight_hw as a flat
// index over the wS[0] x wS[1] kernel window, so this presumably moves the
// thread to the next kernel position it is responsible for.
190 METAL_FUNC void next() {
191 weight_hw += TCOLS;
192 }
193};
194
195template <
196 typename T,
197 short BM,
198 short BN,
199 short BK,
200 short tgp_size,
201 short n_channels,
202 short tgp_padding = 0>
204 // Destination dimensions
205 STEEL_CONST short BROWS = BN;
206 STEEL_CONST short BCOLS = BK;
207
208 // Read dimensions
209 STEEL_CONST short dst_ld = BCOLS + tgp_padding;
211
212 // Thread read shape
214 STEEL_CONST short TROWS = tgp_size / TCOLS;
215
216 // Rows / strided reads within the block
218
219 // Leading dimension for src
220 const int src_ld;
221
222 // Thread location indices
223 const short thread_idx;
224 const short bi;
225 const short bj;
226
227 // threadgroup and device memory
228 threadgroup T* dst;
229 const device T* src;
230
231 const constant MLXConvParams<2>* params;
232
234
235 const int read_n;
236 const bool do_read;
237
238 /* Constructor */
// Positions this thread's source and destination pointers at row bi of the
// weight tile and precomputes the row-range flag used by load_unsafe().
// NOTE(review): the first line of the signature (listing line 239) and the
// initializers on listing lines 249, 250 and 254 are absent from this extract;
// bi, bj and weight_hw are presumably initialized there — confirm against the
// full header.
//
// Parameters:
//   src_          base pointer of the weights in device memory
//   dst_          base pointer of the destination tile in threadgroup memory
//   offsets       only offsets.y is read here; offsets.y + bi becomes read_n
//   params_       convolution parameters (wt_strides[0] is read here as src_ld)
//   gemm_params_  only gemm_params_->N is read here, to compute do_read
//   simd_group_id, simd_lane_id
//                 combined into thread_idx as simd_group_id * 32 + simd_lane_id
//                 (the factor 32 is presumably the simdgroup width)
240 const device T* src_,
241 threadgroup T* dst_,
242 const int2 offsets,
243 const constant MLXConvParams<2>* params_,
244 const constant ImplicitGemmConv2DParams* gemm_params_,
245 uint simd_group_id [[simdgroup_index_in_threadgroup]],
246 uint simd_lane_id [[thread_index_in_simdgroup]])
// Row pitch of the source, taken from the first weight stride.
247 : src_ld(params_->wt_strides[0]),
248 thread_idx(simd_group_id * 32 + simd_lane_id),
// dst: row bi, column bj of the tile (row pitch dst_ld); src: row bi of the source.
251 dst(dst_ + bi * dst_ld + bj),
252 src(src_ + bi * src_ld),
253 params(params_),
// First row index read by this thread.
255 read_n(offsets.y + bi),
// True when read_n + BN does not exceed gemm_params_->N; load_unsafe() uses
// it to skip the per-row range test in the BN == 8 case.
256 do_read(read_n + BN <= gemm_params_->N) {}
257
258 /* Load from device memory into threadgroup memory - without bound checking */
// Writes vec_size values per owned row into dst for the kernel position
// selected by weight_hw: n_channels values copied from the source followed by
// zeros up to vec_size. Rows are range-checked against params->O only on the
// BN == 8 && !do_read path; the other path copies unconditionally.
// NOTE(review): several listing lines inside this body (264, 266, 278, 280,
// 285, 293, 298, 303) are absent from this extract — presumably unroll pragmas
// placed before the loops; confirm against the full header.
259 METAL_FUNC void load_unsafe() const {
// Threads whose (bi, bj) falls outside the BROWS x BCOLS tile write nothing.
260 if (bi >= BROWS || bj >= BCOLS)
261 return;
262
// First row is already out of range, or weight_hw is past the last position
// of the wS[0] x wS[1] kernel window: clear this thread's slots and stop.
263 if (read_n >= params->O || weight_hw >= params->wS[1] * params->wS[0]) {
265 for (short i = 0; i < BROWS; i += TROWS) {
267 for (short j = 0; j < vec_size; j++) {
268 dst[i * dst_ld + j] = T(0);
269 }
270 }
271
272 return;
273 }
274
// Step to the current kernel position: weight_hw strides of wt_strides[2].
275 const device T* curr_src = src + weight_hw * params->wt_strides[2];
276
// Path without per-row checks: taken for every BN other than 8, and for
// BN == 8 when do_read says the rows fit.
277 if (BN != 8 || do_read) {
279 for (short i = 0; i < BROWS; i += TROWS) {
// Copy the n_channels real values.
281 for (short j = 0; j < n_channels; j++) {
282 dst[i * dst_ld + j] = curr_src[i * src_ld + j];
283 }
284
// Zero the remaining slots up to vec_size.
286 for (short j = n_channels; j < vec_size; j++) {
287 dst[i * dst_ld + j] = T(0);
288 }
289 }
290 } else {
// Checked path (BN == 8 and !do_read): test each row against params->O and
// zero-fill rows that fall outside.
291 for (short i = 0; i < BROWS; i += TROWS) {
292 if (((read_n + i) < params->O)) {
294 for (short j = 0; j < n_channels; j++) {
295 dst[i * dst_ld + j] = curr_src[i * src_ld + j];
296 }
297
299 for (short j = n_channels; j < vec_size; j++) {
300 dst[i * dst_ld + j] = T(0);
301 }
302 } else {
304 for (short j = 0; j < vec_size; j++) {
305 dst[i * dst_ld + j] = T(0);
306 }
307 }
308 }
309 }
310 }
311
312 /* Iteration helper */
// Advances weight_hw by TCOLS. load_unsafe() multiplies weight_hw by
// wt_strides[2] and compares it against wS[0] * wS[1], so this presumably
// moves the thread to the next kernel position it is responsible for.
313 METAL_FUNC void next() {
314 weight_hw += TCOLS;
315 }
316};
317
318} // namespace steel
319} // namespace mlx
Definition allocator.h:7
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
Definition params.h:6
const int oS[NDIM]
Definition params.h:12
const int kdil[NDIM]
Definition params.h:15
const int str[NDIM]
Definition params.h:13
const size_t wt_strides[NDIM+2]
Definition params.h:18
const bool flip
Definition params.h:21
const size_t in_strides[NDIM+2]
Definition params.h:17
const int wS[NDIM]
Definition params.h:11
const int O
Definition params.h:9
const int pad[NDIM]
Definition params.h:14
Definition loader_channel_n.h:17
STEEL_CONST short vec_size
Definition loader_channel_n.h:19
STEEL_CONST short n_channels
Definition loader_channel_n.h:18
STEEL_CONST short excess
Definition loader_channel_n.h:20
Definition loader_channel_n.h:59
const constant MLXConvParams< 2 > * params
Definition loader_channel_n.h:83
STEEL_CONST short BROWS
Definition loader_channel_n.h:61
threadgroup T * dst
Definition loader_channel_n.h:81
int read_ih[n_rows]
Definition loader_channel_n.h:91
STEEL_CONST short vec_size
Definition loader_channel_n.h:66
STEEL_CONST short TROWS
Definition loader_channel_n.h:70
const short bj
Definition loader_channel_n.h:78
short weight_hw
Definition loader_channel_n.h:86
STEEL_CONST short n_rows
Definition loader_channel_n.h:73
STEEL_CONST short BCOLS
Definition loader_channel_n.h:62
const short thread_idx
Definition loader_channel_n.h:76
const short bi
Definition loader_channel_n.h:77
METAL_FUNC void load_unsafe() const
Definition loader_channel_n.h:134
int read_iw[n_rows]
Definition loader_channel_n.h:92
METAL_FUNC Conv2DInputBlockLoaderSmallChannels(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant ImplicitGemmConv2DParams *gemm_params_, uint simd_group_id, uint simd_lane_id)
Definition loader_channel_n.h:95
STEEL_CONST short TCOLS
Definition loader_channel_n.h:69
int read_n[n_rows]
Definition loader_channel_n.h:90
STEEL_CONST short dst_ld
Definition loader_channel_n.h:65
const constant ImplicitGemmConv2DParams * gemm_params
Definition loader_channel_n.h:84
METAL_FUNC void next()
Definition loader_channel_n.h:190
const device T * src[n_rows]
Definition loader_channel_n.h:88
Definition loader_channel_n.h:203
STEEL_CONST short vec_size
Definition loader_channel_n.h:210
METAL_FUNC void load_unsafe() const
Definition loader_channel_n.h:259
threadgroup T * dst
Definition loader_channel_n.h:228
METAL_FUNC void next()
Definition loader_channel_n.h:313
int weight_hw
Definition loader_channel_n.h:233
STEEL_CONST short TROWS
Definition loader_channel_n.h:214
const bool do_read
Definition loader_channel_n.h:236
const device T * src
Definition loader_channel_n.h:229
STEEL_CONST short BCOLS
Definition loader_channel_n.h:206
const int read_n
Definition loader_channel_n.h:235
const int src_ld
Definition loader_channel_n.h:220
STEEL_CONST short dst_ld
Definition loader_channel_n.h:209
const short thread_idx
Definition loader_channel_n.h:223
STEEL_CONST short BROWS
Definition loader_channel_n.h:205
STEEL_CONST short TCOLS
Definition loader_channel_n.h:213
const short bj
Definition loader_channel_n.h:225
METAL_FUNC Conv2DWeightBlockLoaderSmallChannels(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant ImplicitGemmConv2DParams *gemm_params_, uint simd_group_id, uint simd_lane_id)
Definition loader_channel_n.h:239
const short bi
Definition loader_channel_n.h:224
STEEL_CONST short n_rows
Definition loader_channel_n.h:217
const constant MLXConvParams< 2 > * params
Definition loader_channel_n.h:231