MLX
Loading...
Searching...
No Matches
loader_general.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
6
8
10// Loading helper
12
13namespace mlx {
14namespace steel {
15
// Template header and data members of the input-side block loader (the
// cross-reference index names its constructor Conv2DInputBlockLoaderGeneral).
// The destination tile is BROWS x BCOLS = BM x BK elements with a row pitch of
// BCOLS + tgp_padding. BN is not referenced by any line visible here.
// NOTE(review): the embedded line numbers skip 23, 33, 37, 48 and 58-60. The
// index attributes the struct header, TCOLS, n_rows, jump_params and
// read_n / read_ih / read_iw to those lines - restore them from the real
// header before relying on this listing.
16template <
17 typename T,
18 short BM,
19 short BN,
20 short BK,
21 short tgp_size,
22 short tgp_padding = 0>
24 // Destination dimensions
25 STEEL_CONST short BROWS = BM;
26 STEEL_CONST short BCOLS = BK;
27
28 // Read dimensions
// dst_ld is the distance, in elements, between consecutive destination rows.
29 STEEL_CONST short dst_ld = BCOLS + tgp_padding;
// Number of contiguous elements each thread copies per row (see load_unsafe).
// NOTE(review): the quotient is tgp_size over the tile element count, so it
// only reaches 8 when tgp_size >= 8 * BROWS * BCOLS - confirm this ratio is
// the intended one.
30 STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;
31
32 // Thread read shape
// TCOLS is defined on a line this listing omits.
34 STEEL_CONST short TROWS = tgp_size / TCOLS;
35
36 // Rows / strided reads within the block
38
39 // Thread location indices
// thread_idx is simd_group_id * 32 + simd_lane_id (set in the constructor);
// the initialisers of bi and bj are not visible in this listing.
40 const short thread_idx;
41 const short bi;
42 const short bj;
43
44 // threadgroup and device memory
// Already offset by bi * dst_ld + bj in the constructor.
45 threadgroup T* dst;
46
47 const constant MLXConvParams<2>* params;
49
// Filter position the loader starts from; weight_h / weight_w are initialised
// to these values in the constructor.
50 const short base_wh;
51 const short base_ww;
52
// Current filter position, read by load_unsafe() and tested by next().
53 short weight_h;
54 short weight_w;
55
// One source pointer per row handled by this thread; filled in the
// constructor and advanced by BK in next().
56 const device T* src[n_rows];
57
61
62 /* Constructor */
// Constructor parameter list and body (the line carrying the constructor name,
// original line 63, is missing from this listing).
//   src_          base pointer of the input data
//   dst_          base pointer of the destination tile
//   offsets       .y is added to the row index, .z / .w are added to the
//                 output h / w coordinates (.x is not read here)
//   params_       convolution parameters (str, pad, in_strides are read here)
//   jump_params_  adj_out_hw, adj_out_w, f_out_jump_h, f_out_jump_w are read
//   base_wh_, base_ww_  starting filter position
//   simd_group_id, simd_lane_id  combined into thread_idx
64 const device T* src_,
65 threadgroup T* dst_,
66 const int4 offsets,
67 const constant MLXConvParams<2>* params_,
68 const constant Conv2DGeneralJumpParams* jump_params_,
69 const short base_wh_,
70 const short base_ww_,
71 uint simd_group_id [[simdgroup_index_in_threadgroup]],
72 uint simd_lane_id [[thread_index_in_simdgroup]])
// The factor 32 hard-codes the number of lanes per simdgroup.
73 : thread_idx(simd_group_id * 32 + simd_lane_id),
// NOTE(review): original lines 74-75 (presumably the bi / bj initialisers)
// are missing from this listing - confirm against the real header.
76 dst(dst_ + bi * dst_ld + bj),
77 params(params_),
78 jump_params(jump_params_),
79 base_wh(base_wh_),
80 base_ww(base_ww_),
81 weight_h(base_wh_),
82 weight_w(base_ww_) {
// For each row this thread handles, decode the flattened row index into
// (n, oh, ow) and cache the corresponding input coordinates.
84 for (short i = 0; i < n_rows; ++i) {
// Rows handled by one thread are TROWS apart.
85 int offset_nhw = offsets.y + bi + i * TROWS;
// Quotient selects the batch entry, remainder the position inside it.
86 int n = offset_nhw / jump_params->adj_out_hw;
87 int hw = offset_nhw % jump_params->adj_out_hw;
// Scale the adjusted coordinates by the output jump and add the base offset.
88 int oh =
89 (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + offsets.z;
90 int ow =
91 (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + offsets.w;
92
// Input coordinate before any filter-tap offset; may be negative due to pad.
93 int ih = oh * params->str[0] - params->pad[0];
94 int iw = ow * params->str[1] - params->pad[1];
95
96 read_n[i] = n;
97 read_ih[i] = ih;
98 read_iw[i] = iw;
99
100 // Row base pointer: batch offset plus bj (bounds are tested in load_unsafe)
101 src[i] = src_ + n * params->in_strides[0] + bj;
102 }
103 }
104
105 /* Load from device memory into threadgroup memory - without bound checking */
// Copies vec_size elements per handled row from src[i] into dst for the
// current filter position (weight_h, weight_w).
// NOTE: despite the "without bound checking" header, this body does test n,
// ih and iw against params and writes T(0) for rows that fail the test.
106 METAL_FUNC void load_unsafe() const {
// `is` is the destination row index; it advances by TROWS per handled row.
108 for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
109 // Find bounds
110 int n = read_n[i];
111
// When params->flip is set the filter index is mirrored: wS - 1 - weight.
112 int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
113 int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;
114
// Offset the cached coordinate by the filter tap scaled by kernel dilation.
115 int ih_dil = read_ih[i] + h_flip * params->kdil[0];
116 int iw_dil = read_iw[i] + w_flip * params->kdil[1];
117
// NOTE(review): integer division truncates, so nothing here rejects values
// that are not multiples of idil; presumably the caller's jump parameters
// guarantee divisibility - TODO confirm.
118 int ih = ih_dil / params->idil[0];
119 int iw = iw_dil / params->idil[1];
120
// Computed unconditionally but only dereferenced in the in-bounds branch.
121 size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];
122
123 // Read from input if in bounds
// The lower bound is tested on the pre-division value, the upper bound on the
// divided index.
124 if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
125 (iw_dil >= 0 && iw < params->iS[1])) {
127 for (short j = 0; j < vec_size; ++j) {
128 dst[is * dst_ld + j] = (src[i])[offset + j];
129 }
130 }
131
132 // Zero pad otherwise
133 else {
135 for (short j = 0; j < vec_size; ++j) {
136 dst[is * dst_ld + j] = T(0);
137 }
138 }
139 }
140 }
141
142 /* Iteration helper */
// Steps the loader to its next position: returns early while weight_w, then
// weight_h, is still below the filter size params->wS; otherwise advances
// every source pointer by BK elements.
// NOTE(review): original lines 144, 149, 151, 156 and 158 are missing from
// this listing. Presumably they advance weight_w / weight_h (the index lists
// f_wgt_jump_w / f_wgt_jump_h) and reset them to base_ww / base_wh - confirm
// against the real header; as shown, the counters never change.
143 METAL_FUNC void next() {
// Still inside the current filter row: nothing else to update.
145 if (weight_w < params->wS[1]) {
146 return;
147 }
148
150
// Still inside the filter: nothing else to update.
152 if (weight_h < params->wS[0]) {
153 return;
154 }
155
157
// Filter exhausted: move each row pointer forward by BK elements.
159 for (short i = 0; i < n_rows; i++) {
160 src[i] += BK;
161 }
162 }
163};
164
// Template header and data members of the weight-side block loader (the
// cross-reference index names its constructor Conv2DWeightBlockLoaderGeneral).
// The destination tile is BROWS x BCOLS = BN x BK elements with a row pitch of
// BCOLS + tgp_padding. BM is not referenced by any line visible here.
// NOTE(review): the embedded line numbers skip 172, 179, 183, 187 and 202.
// The index attributes the struct header, the start of the vec_size
// definition, TCOLS, n_rows and jump_params to those lines - restore them
// from the real header before relying on this listing.
165template <
166 typename T,
167 short BM,
168 short BN,
169 short BK,
170 short tgp_size,
171 short tgp_padding = 0>
173 // Destination dimensions
174 STEEL_CONST short BROWS = BN;
175 STEEL_CONST short BCOLS = BK;
176
177 // Read dimensions
178 STEEL_CONST short dst_ld = BCOLS + tgp_padding;
// Tail of the vec_size definition (its first line, 179, is missing): 1 when
// BN == 8, otherwise 8 or 4 chosen by the same ratio test as the input loader.
180 (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);
181
182 // Thread read shape
// TCOLS is defined on a line this listing omits.
184 STEEL_CONST short TROWS = tgp_size / TCOLS;
185
186 // Rows / strided reads within the block
188
189 // Leading dimension for src
// Initialised from params_->wt_strides[0]; the index declares wt_strides as
// size_t, so the value is narrowed to int here.
190 const int src_ld;
191
192 // Thread location indices
// Declared before dst / src so that those initialisers can use bi, bj and
// src_ld (members are initialised in declaration order).
193 const short thread_idx;
194 const short bi;
195 const short bj;
196
197 // threadgroup and device memory
// dst is offset by bi * dst_ld + bj and src by bi * src_ld + bj in the
// constructor; src is further advanced by BK in next().
198 threadgroup T* dst;
199 const device T* src;
200
201 const constant MLXConvParams<2>* params;
203
// Filter position the loader starts from.
204 const short base_wh;
205 const short base_ww;
206
// Current filter position, read by load_unsafe() and tested by next().
207 short weight_h;
208 short weight_w;
209
// offsets.y + bi; compared against params->O in load_unsafe().
210 const int start_row;
211
212 /* Constructor */
// Constructor parameter list and initialisers (the line carrying the
// constructor name, original line 213, is missing from this listing).
//   src_          base pointer of the weight data
//   dst_          base pointer of the destination tile
//   offsets       only .y is read here; it is added to bi to form start_row
//   params_       convolution parameters (wt_strides[0] is read here)
//   jump_params_  stored for later use
//   base_wh_, base_ww_  starting filter position
//   simd_group_id, simd_lane_id  combined into thread_idx
214 const device T* src_,
215 threadgroup T* dst_,
216 const int2 offsets,
217 const constant MLXConvParams<2>* params_,
218 const constant Conv2DGeneralJumpParams* jump_params_,
219 const short base_wh_,
220 const short base_ww_,
221 uint simd_group_id [[simdgroup_index_in_threadgroup]],
222 uint simd_lane_id [[thread_index_in_simdgroup]])
// Uses the params_ argument rather than the params member, which is
// initialised later in the list.
223 : src_ld(params_ -> wt_strides[0]),
// The factor 32 hard-codes the number of lanes per simdgroup.
224 thread_idx(simd_group_id * 32 + simd_lane_id),
// NOTE(review): original lines 225-226 (presumably the bi / bj initialisers)
// are missing from this listing - confirm against the real header.
227 dst(dst_ + bi * dst_ld + bj),
228 src(src_ + bi * src_ld + bj),
229 params(params_),
230 jump_params(jump_params_),
231 base_wh(base_wh_),
232 base_ww(base_ww_),
233 weight_h(base_wh_),
234 weight_w(base_ww_),
235 start_row(offsets.y + bi) {}
236
237 /* Load from device memory into threadgroup memory - without bound checking */
// Copies vec_size elements per handled row from the weight data into dst for
// the current filter position.
// NOTE: despite the "without bound checking" header, rows are tested against
// params->O and rows at or beyond it are written as T(0).
238 METAL_FUNC void load_unsafe() const {
// NOTE(review): the second operand of this sum is on original line 240, which
// is missing from this listing; presumably the weight_w term - confirm.
239 const device T* curr_src = src + weight_h * params->wt_strides[1] +
241
// Fast path: start_row + BN <= O, so no per-row test is needed.
242 if ((start_row + BN <= params->O)) {
// Rows handled by one thread are TROWS apart.
244 for (short i = 0; i < BN; i += TROWS) {
246 for (short j = 0; j < vec_size; j++) {
247 dst[i * dst_ld + j] = curr_src[i * src_ld + j];
248 }
249 }
250 } else {
// Slow path: test each row and zero-fill those at or beyond O.
251 for (short i = 0; i < BN; i += TROWS) {
252 if ((start_row + i) < params->O) {
254 for (short j = 0; j < vec_size; j++) {
255 dst[i * dst_ld + j] = curr_src[i * src_ld + j];
256 }
257 } else {
259 for (short j = 0; j < vec_size; j++) {
260 dst[i * dst_ld + j] = T(0);
261 }
262 }
263 }
264 }
265 }
266
267 /* Iteration helper */
// Steps the loader to its next position: returns early while weight_w, then
// weight_h, is still below the filter size params->wS; otherwise advances src
// by BK elements.
// NOTE(review): original lines 269, 274, 276 and 281 are missing from this
// listing. Presumably they advance weight_w / weight_h (the index lists
// f_wgt_jump_w / f_wgt_jump_h) and reset them to base_ww / base_wh - confirm
// against the real header; as shown, the counters never change.
268 METAL_FUNC void next() {
// Still inside the current filter row: nothing else to update.
270 if (weight_w < params->wS[1]) {
271 return;
272 }
273
275
// Still inside the filter: nothing else to update.
277 if (weight_h < params->wS[0]) {
278 return;
279 }
280
282
// Filter exhausted: move the source pointer forward by BK elements.
283 src += BK;
284 }
285};
286
287} // namespace steel
288} // namespace mlx
#define STEEL_PRAGMA_UNROLL
Definition utils.h:8
#define STEEL_CONST
Definition utils.h:7
Definition allocator.h:7
Definition params.h:6
const int kdil[NDIM]
Definition params.h:15
const int str[NDIM]
Definition params.h:13
const size_t wt_strides[NDIM+2]
Definition params.h:18
const bool flip
Definition params.h:21
const size_t in_strides[NDIM+2]
Definition params.h:17
const int wS[NDIM]
Definition params.h:11
const int O
Definition params.h:9
const int pad[NDIM]
Definition params.h:14
const int idil[NDIM]
Definition params.h:16
const int f_out_jump_w
Definition params.h:48
const int f_wgt_jump_h
Definition params.h:44
const int f_wgt_jump_w
Definition params.h:45
const int f_out_jump_h
Definition params.h:47
const int adj_out_w
Definition params.h:51
const int adj_out_hw
Definition params.h:52
Definition loader_general.h:23
short weight_w
Definition loader_general.h:54
STEEL_CONST short dst_ld
Definition loader_general.h:29
const constant MLXConvParams< 2 > * params
Definition loader_general.h:47
STEEL_CONST short vec_size
Definition loader_general.h:30
METAL_FUNC Conv2DInputBlockLoaderGeneral(const device T *src_, threadgroup T *dst_, const int4 offsets, const constant MLXConvParams< 2 > *params_, const constant Conv2DGeneralJumpParams *jump_params_, const short base_wh_, const short base_ww_, uint simd_group_id, uint simd_lane_id)
Definition loader_general.h:63
const device T * src[n_rows]
Definition loader_general.h:56
const constant Conv2DGeneralJumpParams * jump_params
Definition loader_general.h:48
STEEL_CONST short TROWS
Definition loader_general.h:34
const short bi
Definition loader_general.h:41
const short base_ww
Definition loader_general.h:51
int read_ih[n_rows]
Definition loader_general.h:59
METAL_FUNC void load_unsafe() const
Definition loader_general.h:106
short weight_h
Definition loader_general.h:53
STEEL_CONST short BCOLS
Definition loader_general.h:26
METAL_FUNC void next()
Definition loader_general.h:143
const short thread_idx
Definition loader_general.h:40
int read_iw[n_rows]
Definition loader_general.h:60
threadgroup T * dst
Definition loader_general.h:45
STEEL_CONST short BROWS
Definition loader_general.h:25
STEEL_CONST short n_rows
Definition loader_general.h:37
const short base_wh
Definition loader_general.h:50
const short bj
Definition loader_general.h:42
STEEL_CONST short TCOLS
Definition loader_general.h:33
int read_n[n_rows]
Definition loader_general.h:58
Definition loader_general.h:172
STEEL_CONST short BROWS
Definition loader_general.h:174
const short thread_idx
Definition loader_general.h:193
STEEL_CONST short vec_size
Definition loader_general.h:179
METAL_FUNC void next()
Definition loader_general.h:268
STEEL_CONST short BCOLS
Definition loader_general.h:175
const int start_row
Definition loader_general.h:210
const short base_ww
Definition loader_general.h:205
const short bi
Definition loader_general.h:194
const device T * src
Definition loader_general.h:199
short weight_h
Definition loader_general.h:207
const int src_ld
Definition loader_general.h:190
const short base_wh
Definition loader_general.h:204
short weight_w
Definition loader_general.h:208
threadgroup T * dst
Definition loader_general.h:198
METAL_FUNC void load_unsafe() const
Definition loader_general.h:238
const constant Conv2DGeneralJumpParams * jump_params
Definition loader_general.h:202
STEEL_CONST short dst_ld
Definition loader_general.h:178
STEEL_CONST short n_rows
Definition loader_general.h:187
STEEL_CONST short TROWS
Definition loader_general.h:184
const short bj
Definition loader_general.h:195
METAL_FUNC Conv2DWeightBlockLoaderGeneral(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant Conv2DGeneralJumpParams *jump_params_, const short base_wh_, const short base_ww_, uint simd_group_id, uint simd_lane_id)
Definition loader_general.h:213
const constant MLXConvParams< 2 > * params
Definition loader_general.h:201
STEEL_CONST short TCOLS
Definition loader_general.h:183