MLX
 
Loading...
Searching...
No Matches
loader_general.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
6
8// Loading helper
10
11namespace mlx {
12namespace steel {
13
// Cooperative loader for one BM x BK tile of the convolution input, used by
// the general (strided / dilated / flippable) 2D convolution path.
//
// Each of the BROWS (= BM) tile rows corresponds to one output position
// (n, oh, ow) decoded from offsets.y and the jump params. Each row is filled
// with the input pixel that output position sees through the current filter
// tap (weight_h, weight_w). Rows whose pixel lies outside the image (or whose
// n >= N) are zero-filled; this is how spatial padding is realised.
//
// next() walks the K dimension in two stages. It first steps across filter
// taps (w, then h, using the jump sizes in jump_params). It then moves BK
// elements along the contiguous input dimension, presumably channels of an
// NHWC tensor (in_strides[0..2] are applied to n, h, w); confirm against the
// caller.
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DInputBlockLoaderGeneral {
  // Destination dimensions
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  // NOTE(review): this yields 8 only when tgp_size >= 8 * BROWS * BCOLS, i.e.
  // more than eight threads per tile element; confirm whether
  // (BROWS * BCOLS) / tgp_size was intended.
  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;

  // Thread read shape. The tgp_size threads form a TROWS x TCOLS grid, and
  // each thread copies vec_size consecutive columns of a row.
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block: each thread visits n_rows rows,
  // TROWS apart.
  // NOTE(review): assumes TROWS divides BROWS; nothing here enforces it.
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Thread location indices (assumes 32 threads per SIMD-group, as hard-coded
  // in the constructor)
  const short thread_idx;
  const short bi; // first tile row owned by this thread
  const short bj; // first tile column owned by this thread

  // threadgroup and device memory
  threadgroup T* dst;

  const constant MLXConvParams<2>* params;
  const constant Conv2DGeneralJumpParams* jump_params;

  // Starting filter tap; next() rewinds to these after each wrap
  const short base_wh;
  const short base_ww;

  // Current filter tap
  short weight_h;
  short weight_w;

  // Per-row state captured once in the constructor
  const device T* src[n_rows]; // batch base pointer + column offset bj
  int read_n[n_rows]; // batch index
  int read_ih[n_rows]; // oh * stride - pad, before adding the tap offset
  int read_iw[n_rows]; // ow * stride - pad, before adding the tap offset

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoaderGeneral(
      const device T* src_,
      threadgroup T* dst_,
      const int4 offsets,
      const constant MLXConvParams<2>* params_,
      const constant Conv2DGeneralJumpParams* jump_params_,
      const short base_wh_,
      const short base_ww_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        jump_params(jump_params_),
        base_wh(base_wh_),
        base_ww(base_ww_),
        weight_h(base_wh_),
        weight_w(base_ww_) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      // Decode this tile row into an output position (n, oh, ow)
      int offset_nhw = offsets.y + bi + i * TROWS;
      int n = offset_nhw / jump_params->adj_out_hw;
      int hw = offset_nhw % jump_params->adj_out_hw;
      int oh =
          (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + offsets.z;
      int ow =
          (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + offsets.w;

      // Top-left input coordinate seen by that output position
      int ih = oh * params->str[0] - params->pad[0];
      int iw = ow * params->str[1] - params->pad[1];

      read_n[i] = n;
      read_ih[i] = ih;
      read_iw[i] = iw;

      // Row base pointer; the spatial offset is added per filter tap in
      // load_unsafe()
      src[i] = src_ + n * params->in_strides[0] + bj;
    }
  }

  /* Load the current tile into threadgroup memory.
   *
   * Batch and spatial bounds are checked per row, and out-of-range rows are
   * zero-filled. The column range [bj, bj + vec_size) is NOT checked here.
   */
  METAL_FUNC void load_unsafe() const {
    // The filter tap is the same for every row, so resolve flip and kernel
    // dilation once.
    const int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
    const int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;
    const int tap_h = h_flip * params->kdil[0];
    const int tap_w = w_flip * params->kdil[1];

    STEEL_PRAGMA_UNROLL
    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
      const int n = read_n[i];

      // Coordinate in the input-dilated image, then in the stored image.
      // NOTE(review): divisibility by idil is not checked here; this
      // presumably relies on the jump params visiting only aligned taps.
      const int ih_dil = read_ih[i] + tap_h;
      const int iw_dil = read_iw[i] + tap_w;
      const int ih = ih_dil / params->idil[0];
      const int iw = iw_dil / params->idil[1];

      const bool in_bounds = (n < params->N) &&
          (ih_dil >= 0 && ih < params->iS[0]) &&
          (iw_dil >= 0 && iw < params->iS[1]);

      if (in_bounds) {
        // Only form the offset once ih/iw are known to be valid
        const size_t offset =
            ih * params->in_strides[1] + iw * params->in_strides[2];

        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = (src[i])[offset + j];
        }
      } else {
        // Zero pad
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper: advance to the next K slice.
   *
   * Steps the filter column. When the column wraps, it is rewound and the
   * filter row is stepped. When the row also wraps, it is rewound and every
   * row pointer moves BK elements forward.
   */
  METAL_FUNC void next() {
    weight_w += jump_params->f_wgt_jump_w;
    if (weight_w < params->wS[1]) {
      return;
    }

    // Rewind the column before stepping the row. Without this, weight_w
    // would stay >= wS[1] and every later call would skip filter taps.
    weight_w = base_ww;

    weight_h += jump_params->f_wgt_jump_h;
    if (weight_h < params->wS[0]) {
      return;
    }

    // All taps visited: rewind the row and move on to the next BK block
    weight_h = base_wh;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; i++) {
      src[i] += BK;
    }
  }
};
162
// Cooperative loader for one BN x BK tile of the convolution weights, used by
// the general (strided / dilated / flippable) 2D convolution path.
//
// Tile rows are bounded by params->O (presumably output channels). Rows at or
// past O are zero-filled. Row r of the tile reads from
//   src + r * wt_strides[0] + weight_h * wt_strides[1] + weight_w * wt_strides[2]
// plus a contiguous run of columns.
//
// next() walks the K dimension in the same order as the input loader. It
// first steps across filter taps (w, then h, using jump_params), then moves
// BK elements forward.
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DWeightBlockLoaderGeneral {
  // Destination dimensions
  STEEL_CONST short BROWS = BN;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  // BN == 8 falls back to one element per thread, presumably so that TROWS
  // stays within the 8-row tile.
  // NOTE(review): the non-BN==8 branch yields 8 only when
  // tgp_size >= 8 * BROWS * BCOLS.
  STEEL_CONST short vec_size =
      (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);

  // Thread read shape. The tgp_size threads form a TROWS x TCOLS grid, and
  // each thread copies vec_size consecutive columns of a row.
  // NOTE(review): load_unsafe() assumes TROWS <= BN; otherwise threads with
  // bi >= BN write outside the tile. Nothing here enforces it.
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Leading dimension for src (row stride of the weight tensor)
  const int src_ld;

  // Thread location indices (assumes 32 threads per SIMD-group, as hard-coded
  // in the constructor)
  const short thread_idx;
  const short bi; // first tile row owned by this thread
  const short bj; // first tile column owned by this thread

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>* params;
  const constant Conv2DGeneralJumpParams* jump_params;

  // Starting filter tap; next() rewinds to these after each wrap
  const short base_wh;
  const short base_ww;

  // Current filter tap
  short weight_h;
  short weight_w;

  // Global row index of this thread's first row, used for the params->O bound
  const int start_row;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoaderGeneral(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant Conv2DGeneralJumpParams* jump_params_,
      const short base_wh_,
      const short base_ww_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_->wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj),
        params(params_),
        jump_params(jump_params_),
        base_wh(base_wh_),
        base_ww(base_ww_),
        weight_h(base_wh_),
        weight_w(base_ww_),
        start_row(offsets.y + bi) {}

  /* Load the current tile into threadgroup memory.
   *
   * Rows are checked against params->O, and rows at or past it are
   * zero-filled. Columns are NOT bounds-checked.
   */
  METAL_FUNC void load_unsafe() const {
    // Weights for the current filter tap
    const device T* curr_src = src + weight_h * params->wt_strides[1] +
        weight_w * params->wt_strides[2];

    if ((start_row + BN <= params->O)) {
      // Every row this thread touches is in range: skip per-row checks
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BN; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = curr_src[i * src_ld + j];
        }
      }
    } else {
      for (short i = 0; i < BN; i += TROWS) {
        if ((start_row + i) < params->O) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = curr_src[i * src_ld + j];
          }
        } else {
          // Zero pad rows past params->O
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = T(0);
          }
        }
      }
    }
  }

  /* Iteration helper: advance to the next K slice.
   *
   * Steps the filter column. When the column wraps, it is rewound and the
   * filter row is stepped. When the row also wraps, it is rewound and src
   * moves BK elements forward.
   */
  METAL_FUNC void next() {
    weight_w += jump_params->f_wgt_jump_w;
    if (weight_w < params->wS[1]) {
      return;
    }

    // Rewind the column before stepping the row. Without this, weight_w
    // would stay >= wS[1] and every later call would skip filter taps.
    weight_w = base_ww;

    weight_h += jump_params->f_wgt_jump_h;
    if (weight_h < params->wS[0]) {
      return;
    }

    // All taps visited: rewind the row and move on to the next BK block
    weight_h = base_wh;

    src += BK;
  }
};
284
285} // namespace steel
286} // namespace mlx
Definition attn.h:19
Definition allocator.h:7
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
Definition params.h:6
short weight_w
Definition loader_general.h:52
STEEL_CONST short dst_ld
Definition loader_general.h:27
const constant MLXConvParams< 2 > * params
Definition loader_general.h:45
STEEL_CONST short vec_size
Definition loader_general.h:28
METAL_FUNC Conv2DInputBlockLoaderGeneral(const device T *src_, threadgroup T *dst_, const int4 offsets, const constant MLXConvParams< 2 > *params_, const constant Conv2DGeneralJumpParams *jump_params_, const short base_wh_, const short base_ww_, uint simd_group_id, uint simd_lane_id)
Definition loader_general.h:61
const device T * src[n_rows]
Definition loader_general.h:54
const constant Conv2DGeneralJumpParams * jump_params
Definition loader_general.h:46
STEEL_CONST short TROWS
Definition loader_general.h:32
const short bi
Definition loader_general.h:39
const short base_ww
Definition loader_general.h:49
int read_ih[n_rows]
Definition loader_general.h:57
METAL_FUNC void load_unsafe() const
Definition loader_general.h:104
short weight_h
Definition loader_general.h:51
STEEL_CONST short BCOLS
Definition loader_general.h:24
METAL_FUNC void next()
Definition loader_general.h:141
const short thread_idx
Definition loader_general.h:38
int read_iw[n_rows]
Definition loader_general.h:58
threadgroup T * dst
Definition loader_general.h:43
STEEL_CONST short BROWS
Definition loader_general.h:23
STEEL_CONST short n_rows
Definition loader_general.h:35
const short base_wh
Definition loader_general.h:48
const short bj
Definition loader_general.h:40
STEEL_CONST short TCOLS
Definition loader_general.h:31
int read_n[n_rows]
Definition loader_general.h:56
STEEL_CONST short BROWS
Definition loader_general.h:172
const short thread_idx
Definition loader_general.h:191
STEEL_CONST short vec_size
Definition loader_general.h:177
METAL_FUNC void next()
Definition loader_general.h:266
STEEL_CONST short BCOLS
Definition loader_general.h:173
const int start_row
Definition loader_general.h:208
const short base_ww
Definition loader_general.h:203
const short bi
Definition loader_general.h:192
const device T * src
Definition loader_general.h:197
short weight_h
Definition loader_general.h:205
const int src_ld
Definition loader_general.h:188
const short base_wh
Definition loader_general.h:202
short weight_w
Definition loader_general.h:206
threadgroup T * dst
Definition loader_general.h:196
METAL_FUNC void load_unsafe() const
Definition loader_general.h:236
const constant Conv2DGeneralJumpParams * jump_params
Definition loader_general.h:200
STEEL_CONST short dst_ld
Definition loader_general.h:176
STEEL_CONST short n_rows
Definition loader_general.h:185
STEEL_CONST short TROWS
Definition loader_general.h:182
const short bj
Definition loader_general.h:193
METAL_FUNC Conv2DWeightBlockLoaderGeneral(const device T *src_, threadgroup T *dst_, const int2 offsets, const constant MLXConvParams< 2 > *params_, const constant Conv2DGeneralJumpParams *jump_params_, const short base_wh_, const short base_ww_, uint simd_group_id, uint simd_lane_id)
Definition loader_general.h:211
const constant MLXConvParams< 2 > * params
Definition loader_general.h:199
STEEL_CONST short TCOLS
Definition loader_general.h:181