MLX
Loading...
Searching...
No Matches
loader_channel_l.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
6
8
10// Loading helper
12
13namespace mlx {
14namespace steel {
15
16template <
17 typename T,
18 short BM,
19 short BN,
20 short BK,
21 short tgp_size,
22 short tgp_padding = 0>
24 // Destination dimensions
25 STEEL_CONST short BROWS = BM;
26 STEEL_CONST short BCOLS = BK;
27
28 // Read dimensions
29 STEEL_CONST short dst_ld = BCOLS + tgp_padding;
30 STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;
31
32 // Thread read shape
34 STEEL_CONST short TROWS = tgp_size / TCOLS;
35
36 // Rows / strided reads within the block
38
39 // Thread location indices
40 const short thread_idx;
41 const short bi;
42 const short bj;
43
44 // threadgroup and device memory
45 threadgroup T* dst;
46
47 const constant MLXConvParams<2>* params;
49
50 short weight_h;
51 short weight_w;
52
53 const device T* src[n_rows];
54
58
59 /* Constructor */
61 const device T* src_,
62 threadgroup T* dst_,
63 const int2 offsets,
64 const constant MLXConvParams<2>* params_,
65 const constant ImplicitGemmConv2DParams* gemm_params_,
66 uint simd_group_id [[simdgroup_index_in_threadgroup]],
67 uint simd_lane_id [[thread_index_in_simdgroup]])
68 : thread_idx(simd_group_id * 32 + simd_lane_id),
71 dst(dst_ + bi * dst_ld + bj),
72 params(params_),
73 gemm_params(gemm_params_),
74 weight_h(0),
75 weight_w(0) {
76 int out_n_pixels = params->oS[0] * params->oS[1];
77
79 for (short i = 0; i < n_rows; ++i) {
80 int offset_nhw = offsets.y + bi + i * TROWS;
81 int n = offset_nhw / out_n_pixels;
82 int hw = offset_nhw % out_n_pixels;
83 int oh = hw / params->oS[1];
84 int ow = hw % params->oS[1];
85
86 int ih = oh * params->str[0] - params->pad[0];
87 int iw = ow * params->str[1] - params->pad[1];
88
89 read_n[i] = n;
90 read_ih[i] = ih;
91 read_iw[i] = iw;
92
93 // Adjust for flip
94 if (params->flip) {
95 ih += (params->wS[0] - 1) * params->kdil[0];
96 iw += (params->wS[1] - 1) * params->kdil[1];
97 }
98
99 // Read from input if in bounds
100 src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
101 iw * params->in_strides[2] + bj;
102 }
103 }
104
105 /* Load this thread's rows of the input tile for the current filter tap
 * (weight_h, weight_w) from device memory into threadgroup memory.
 * Despite the "unsafe" name, every row IS bounds-checked: rows whose batch
 * index n or input position (ih, iw) fall outside the input are written as
 * zeros, which realises the convolution padding (the constructor derives ih/iw
 * as o * str - pad). Only the vec_size contiguous values read from each row,
 * src[i][0 .. vec_size), are not checked.
 *
 * NOTE(review): when params->flip is set, the constructor points src[i] at the
 * last tap ((wS - 1) * kdil added to ih/iw). However, read_ih/read_iw are
 * stored before that adjustment, and the test below uses weight_h/weight_w
 * unflipped. The small-filter loader instead tests the flipped tap
 * (wS - 1 - k). Confirm this check matches the tap that next() actually moves
 * src[i] to for flipped convolutions. */
106 METAL_FUNC void load_unsafe() const {
108 for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { // is = dst row offset of row i
109 // Input position of row i at tap (weight_h, weight_w), dilation applied
110 int n = read_n[i];
111 int ih = read_ih[i] + weight_h * params->kdil[0];
112 int iw = read_iw[i] + weight_w * params->kdil[1];
113
114 // Read from input if in bounds
115 if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
116 (iw >= 0 && iw < params->iS[1])) {
118 for (short j = 0; j < vec_size; ++j) {
119 dst[is * dst_ld + j] = src[i][j];
120 }
121 }
122
123 // Zero pad otherwise (out-of-range rows contribute nothing to the GEMM)
124 else {
126 for (short j = 0; j < vec_size; ++j) {
127 dst[is * dst_ld + j] = T(0);
128 }
129 }
130 }
131 }
132
133 /* Iteration helper */
134 METAL_FUNC void next() {
135 if (++weight_w < params->wS[1]) {
137 for (short i = 0; i < n_rows; i++) {
139 }
140
141 return;
142 }
143
144 weight_w = 0;
145
146 if (++weight_h < params->wS[0]) {
148 for (short i = 0; i < n_rows; i++) {
150 }
151
152 return;
153 }
154
155 weight_h = 0;
156
158 for (short i = 0; i < n_rows; i++) {
160 }
161 }
162};
163
164template <
165 typename T,
166 short BM,
167 short BN,
168 short BK,
169 short tgp_size,
170 short tgp_padding = 0>
172 // Destination dimensions
173 STEEL_CONST short BROWS = BM;
174 STEEL_CONST short BCOLS = BK;
175
176 // Read dimensions
177 STEEL_CONST short dst_ld = BCOLS + tgp_padding;
178 STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;
179
180 // Thread read shape
182 STEEL_CONST short TROWS = tgp_size / TCOLS;
183
184 // Rows / strided reads within the block
186
187 using mask_t = short;
188
189 // Thread location indices
190 const short thread_idx;
191 const short bi;
192 const short bj;
193
194 // threadgroup and device memory
195 threadgroup T* dst;
196
197 const constant MLXConvParams<2>* params;
199
200 short weight_h;
201 short weight_w;
202
203 const device T* src[n_rows];
204
207
208 /* Constructor */
210 const device T* src_,
211 threadgroup T* dst_,
212 const int2 offsets,
213 const constant MLXConvParams<2>* params_,
214 const constant ImplicitGemmConv2DParams* gemm_params_,
215 uint simd_group_id [[simdgroup_index_in_threadgroup]],
216 uint simd_lane_id [[thread_index_in_simdgroup]])
217 : thread_idx(simd_group_id * 32 + simd_lane_id),
220 dst(dst_ + bi * dst_ld + bj),
221 params(params_),
222 gemm_params(gemm_params_),
223 weight_h(0),
224 weight_w(0) {
225 int out_n_pixels = params->oS[0] * params->oS[1];
226
227 int read_n[n_rows];
228 int read_ih[n_rows];
229 int read_iw[n_rows];
230
232 for (short i = 0; i < n_rows; ++i) {
233 int offset_nhw = offsets.y + bi + i * TROWS;
234 int n = offset_nhw / out_n_pixels;
235 int hw = offset_nhw % out_n_pixels;
236 int oh = hw / params->oS[1];
237 int ow = hw % params->oS[1];
238
239 int ih = oh * params->str[0] - params->pad[0];
240 int iw = ow * params->str[1] - params->pad[1];
241
242 read_n[i] = n;
243 read_ih[i] = ih;
244 read_iw[i] = iw;
245
246 // Adjust for flip
247 if (params->flip) {
248 ih += (params->wS[0] - 1) * params->kdil[0];
249 iw += (params->wS[1] - 1) * params->kdil[1];
250 }
251
252 // Read from input if in bounds
253 src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
254 iw * params->in_strides[2] + bj;
255 }
256
258 for (short i = 0; i < n_rows; ++i) {
259 mask_h[i] = 0;
260 mask_w[i] = 0;
261 }
262
263 for (short kh = 0; kh < params->wS[0]; kh++) {
264 short flip_h = params->flip ? params->wS[0] - kh - 1 : kh;
266 for (short i = 0; i < n_rows; ++i) {
267 int n = read_n[i];
268 int ih = read_ih[i] + flip_h * params->kdil[0];
269
270 bool in_bounds = n < params->N && ih >= 0 && ih < params->iS[0];
271
272 mask_h[i] |= (in_bounds << kh);
273 }
274 }
275
276 for (short kw = 0; kw < params->wS[1]; kw++) {
277 short flip_w = params->flip ? params->wS[1] - kw - 1 : kw;
279 for (short i = 0; i < n_rows; ++i) {
280 int iw = read_iw[i] + flip_w * params->kdil[1];
281
282 bool in_bounds = iw >= 0 && iw < params->iS[1];
283
284 mask_w[i] |= (in_bounds << kw);
285 }
286 }
287 }
288
289 /* Load this thread's rows of the input tile for the current filter tap
 * (weight_h, weight_w) from device memory into threadgroup memory.
 * Instead of recomputing coordinates, each row tests the per-tap bitmasks
 * mask_h[i] / mask_w[i] precomputed in the constructor. In those masks, bit k
 * is set when tap k of that row lies inside the input; the flip is already
 * folded in, and mask_h also encodes n < N. Rows failing either test are
 * zero-filled (convolution padding). The vec_size values read per row are not
 * bounds-checked.
 *
 * NOTE(review): mask_t is short, one bit per tap, so wS[0] and wS[1] must not
 * exceed the bit width of short. Presumably the caller only selects this
 * loader for small filters; confirm at the dispatch site. */
290 METAL_FUNC void load_unsafe() const {
291 mask_t h_mask = mask_t(1) << weight_h; // selects the bit of the current h tap
292 mask_t w_mask = mask_t(1) << weight_w; // selects the bit of the current w tap
293
295 for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { // is = dst row offset of row i
296 // Read from input if in bounds
297 if ((mask_h[i] & h_mask) && (mask_w[i] & w_mask)) {
299 for (short j = 0; j < vec_size; ++j) {
300 dst[is * dst_ld + j] = src[i][j];
301 }
302 }
303
304 // Zero pad otherwise (out-of-range rows contribute nothing to the GEMM)
305 else {
307 for (short j = 0; j < vec_size; ++j) {
308 dst[is * dst_ld + j] = T(0);
309 }
310 }
311 }
312 }
313
314 /* Iteration helper */
315 METAL_FUNC void next() {
316 if (++weight_w < params->wS[1]) {
318 for (short i = 0; i < n_rows; i++) {
320 }
321
322 return;
323 }
324
325 weight_w = 0;
326
327 if (++weight_h < params->wS[0]) {
329 for (short i = 0; i < n_rows; i++) {
331 }
332
333 return;
334 }
335
336 weight_h = 0;
337
339 for (short i = 0; i < n_rows; i++) {
341 }
342 }
343};
344
345template <
346 typename T,
347 short BM,
348 short BN,
349 short BK,
350 short tgp_size,
351 short tgp_padding = 0>
353 // Destination dimensions
354 STEEL_CONST short BROWS = BN;
355 STEEL_CONST short BCOLS = BK;
356
357 // Read dimensions
358 STEEL_CONST short dst_ld = BCOLS + tgp_padding;
360 (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);
361
362 // Thread read shape
364 STEEL_CONST short TROWS = tgp_size / TCOLS;
365
366 // Rows / strided reads within the block
368
369 // Leading dimension for src
370 const int src_ld;
371
372 // Thread location indices
373 const short thread_idx;
374 const short bi;
375 const short bj;
376
377 // threadgroup and device memory
378 threadgroup T* dst;
379 const device T* src;
380
381 const constant MLXConvParams<2>* params;
382
384
385 const int read_n;
386 const bool do_read;
387
388 /* Constructor */
390 const device T* src_,
391 threadgroup T* dst_,
392 const int2 offsets,
393 const constant MLXConvParams<2>* params_,
394 const constant ImplicitGemmConv2DParams* gemm_params_,
395 uint simd_group_id [[simdgroup_index_in_threadgroup]],
396 uint simd_lane_id [[thread_index_in_simdgroup]])
397 : src_ld(params_ -> wt_strides[0]),
398 thread_idx(simd_group_id * 32 + simd_lane_id),
401 dst(dst_ + bi * dst_ld + bj),
402 src(src_ + bi * src_ld + bj),
403 params(params_),
404 weight_hw(0),
405 read_n(offsets.y + bi),
406 do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}
407
408 /* Load this thread's rows of the BN x BK weight tile from device memory into
 * threadgroup memory.
 * - Fast path (BN != 8, or do_read, i.e. read_n + n_rows * TROWS <=
 *   gemm_params->N): rows are copied with no bound check.
 * - Otherwise (BN == 8 near the edge): each row read_n + i is checked against
 *   params->O, and rows at or past it are zero-filled.
 * Columns (the vec_size values per row) are never checked.
 *
 * NOTE(review): the fast-path test uses gemm_params->N while the per-row test
 * uses params->O; presumably these are the same count (GEMM N == number of
 * output channels), so confirm. For BN != 8 no row check happens at all, so
 * the caller presumably guarantees every BN row exists; confirm. */
409 METAL_FUNC void load_unsafe() const {
410 if (BN != 8 || do_read) {
412 for (short i = 0; i < BN; i += TROWS) {
414 for (short j = 0; j < vec_size; j++) {
415 dst[i * dst_ld + j] = src[i * src_ld + j];
416 }
417 }
418 } else {
419 for (short i = 0; i < BN; i += TROWS) {
420 if ((read_n + i) < params->O) { // row exists: copy it
422 for (short j = 0; j < vec_size; j++) {
423 dst[i * dst_ld + j] = src[i * src_ld + j];
424 }
425 } else { // row past the end: zero-fill
427 for (short j = 0; j < vec_size; j++) {
428 dst[i * dst_ld + j] = T(0);
429 }
430 }
431 }
432 }
433 }
434
435 /* Iteration helper: advance src to the next filter tap.
 * weight_hw counts taps in [0, wS[0] * wS[1]), and each step moves src by
 * params->wt_strides[2]. After the last tap, src is rewound by the
 * (taps - 1) steps it took and then moved BK elements forward. Over a full
 * cycle this is a net advance of exactly BK, landing on tap 0 of the next
 * BK-wide slice. Moving by BK directly assumes that slice is contiguous in
 * memory; TODO confirm against the weight layout.
 *
 * NOTE(review): wt_strides is declared size_t in params.h, so the rewind
 * expression below is evaluated in unsigned arithmetic. It relies on modular
 * wrap-around to yield the negative offset; converting the stride to a signed
 * type first would make that explicit. */
436 METAL_FUNC void next() {
437 if (++weight_hw < (params->wS[1] * params->wS[0])) { // more taps remain
438 src += params->wt_strides[2];
439 return;
440 }
441
442 weight_hw = 0; // wrapped past the last tap
443
444 src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2];
445 }
446};
447
448} // namespace steel
449} // namespace mlx
Definition allocator.h:7
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
Definition params.h:6
const int oS[NDIM]
Definition params.h:12
const int iS[NDIM]
Definition params.h:10
const int kdil[NDIM]
Definition params.h:15
const int str[NDIM]
Definition params.h:13
const size_t wt_strides[NDIM+2]
Definition params.h:18
const bool flip
Definition params.h:21
const size_t in_strides[NDIM+2]
Definition params.h:17
const int wS[NDIM]
Definition params.h:11
const int O
Definition params.h:9
const int N
Definition params.h:7
const int pad[NDIM]
Definition params.h:14
Definition loader_channel_l.h:23
STEEL_CONST short n_rows
Definition loader_channel_l.h:37
const constant MLXConvParams< 2 > * params
Definition loader_channel_l.h:47
STEEL_CONST short TCOLS
Definition loader_channel_l.h:33
int read_iw[n_rows]
Definition loader_channel_l.h:57
STEEL_CONST short TROWS
Definition loader_channel_l.h:34
STEEL_CONST short BCOLS
Definition loader_channel_l.h:26
METAL_FUNC void next()
Definition loader_channel_l.h:134
short weight_h
Definition loader_channel_l.h:50
const device T * src[n_rows]
Definition loader_channel_l.h:53
const short thread_idx
Definition loader_channel_l.h:40
const short bj
Definition loader_channel_l.h:42
int read_ih[n_rows]
Definition loader_channel_l.h:56
METAL_FUNC Conv2DInputBlockLoaderLargeFilter(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant ImplicitGemmConv2DParams *gemm_params_, uint simd_group_id, uint simd_lane_id)
Definition loader_channel_l.h:60
const short bi
Definition loader_channel_l.h:41
STEEL_CONST short dst_ld
Definition loader_channel_l.h:29
METAL_FUNC void load_unsafe() const
Definition loader_channel_l.h:106
const constant ImplicitGemmConv2DParams * gemm_params
Definition loader_channel_l.h:48
STEEL_CONST short BROWS
Definition loader_channel_l.h:25
STEEL_CONST short vec_size
Definition loader_channel_l.h:30
short weight_w
Definition loader_channel_l.h:51
threadgroup T * dst
Definition loader_channel_l.h:45
int read_n[n_rows]
Definition loader_channel_l.h:55
Definition loader_channel_l.h:171
METAL_FUNC Conv2DInputBlockLoaderSmallFilter(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant ImplicitGemmConv2DParams *gemm_params_, uint simd_group_id, uint simd_lane_id)
Definition loader_channel_l.h:209
mask_t mask_h[n_rows]
Definition loader_channel_l.h:205
STEEL_CONST short BROWS
Definition loader_channel_l.h:173
mask_t mask_w[n_rows]
Definition loader_channel_l.h:206
short mask_t
Definition loader_channel_l.h:187
short weight_h
Definition loader_channel_l.h:200
STEEL_CONST short TROWS
Definition loader_channel_l.h:182
STEEL_CONST short n_rows
Definition loader_channel_l.h:185
short weight_w
Definition loader_channel_l.h:201
const constant MLXConvParams< 2 > * params
Definition loader_channel_l.h:197
const device T * src[n_rows]
Definition loader_channel_l.h:203
STEEL_CONST short TCOLS
Definition loader_channel_l.h:181
const short bj
Definition loader_channel_l.h:192
STEEL_CONST short vec_size
Definition loader_channel_l.h:178
METAL_FUNC void next()
Definition loader_channel_l.h:315
METAL_FUNC void load_unsafe() const
Definition loader_channel_l.h:290
threadgroup T * dst
Definition loader_channel_l.h:195
STEEL_CONST short dst_ld
Definition loader_channel_l.h:177
const short thread_idx
Definition loader_channel_l.h:190
STEEL_CONST short BCOLS
Definition loader_channel_l.h:174
const constant ImplicitGemmConv2DParams * gemm_params
Definition loader_channel_l.h:198
const short bi
Definition loader_channel_l.h:191
Definition loader_channel_l.h:352
STEEL_CONST short dst_ld
Definition loader_channel_l.h:358
STEEL_CONST short vec_size
Definition loader_channel_l.h:359
const bool do_read
Definition loader_channel_l.h:386
const constant MLXConvParams< 2 > * params
Definition loader_channel_l.h:381
STEEL_CONST short n_rows
Definition loader_channel_l.h:367
const int read_n
Definition loader_channel_l.h:385
METAL_FUNC void load_unsafe() const
Definition loader_channel_l.h:409
const short bj
Definition loader_channel_l.h:375
const int src_ld
Definition loader_channel_l.h:370
const device T * src
Definition loader_channel_l.h:379
STEEL_CONST short TCOLS
Definition loader_channel_l.h:363
STEEL_CONST short BCOLS
Definition loader_channel_l.h:355
const short bi
Definition loader_channel_l.h:374
STEEL_CONST short TROWS
Definition loader_channel_l.h:364
METAL_FUNC Conv2DWeightBlockLoader(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant ImplicitGemmConv2DParams *gemm_params_, uint simd_group_id, uint simd_lane_id)
Definition loader_channel_l.h:389
METAL_FUNC void next()
Definition loader_channel_l.h:436
const short thread_idx
Definition loader_channel_l.h:373
int weight_hw
Definition loader_channel_l.h:383
STEEL_CONST short BROWS
Definition loader_channel_l.h:354
threadgroup T * dst
Definition loader_channel_l.h:378
const int inp_jump_h
Definition params.h:35
const int inp_jump_c
Definition params.h:36
const int inp_jump_w
Definition params.h:34