MLX: loader.h
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/defines.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short alignment = 1,
    short n_reads = (BCOLS * BROWS) / (tgp_size),
    short TCOLS = BCOLS / n_reads,
    short TROWS = tgp_size / TCOLS>
struct BlockLoader {
  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
  STEEL_CONST short vec_size = n_reads;

  // Leading dimension for src
  const int src_ld;
  const int tile_stride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

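  // ReadVector packs one thread's vec_size contiguous elements so that
  // load_unsafe() below can copy them as a single chunk; the
  // alignas(alignment * sizeof(T)) qualifier reflects the alignment the
  // caller promises via the template parameter, so the copy can be
  // vectorized.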
  struct alignas(alignment * sizeof(T)) ReadVector {
    uint8_t v[sizeof(T) * vec_size];
  };

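  // Each thread owns one vec_size-wide segment of the tile: thread_idx is its
  // flat index in the threadgroup, bi the tile row it starts on, and bj the
  // first column of its segment; the load loops then cover BROWS rows in
  // steps of TROWS.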
  /* Constructor */
  METAL_FUNC BlockLoader(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj) {}

  /* Apply operation to threadgroup without bound checking */
  template <typename UnaryOp>
  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
      }
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      *((threadgroup ReadVector*)(&dst[i * dst_ld])) =
          *((const device ReadVector*)(&src[i * src_ld]));
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    src_tile_dim = src_tile_dim - short2(bj, bi);

    // Skip loading if thread has no valid reads
    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
      return;
    }

    // Use fast thread memory for bound checks: out-of-bounds lanes are
    // redirected to index 0 and zeroed afterwards, avoiding divergent branches
    bool tmp_idx[vec_size];
    T tmp_val[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      // Make sure tmp_idx only contains valid indices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
      }

      // Read valid indices into tmp_val
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
      }

      // Zero out unneeded values
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
      }

      // Copy values to threadgroup memory
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = tmp_val[j];
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    src += tile_stride;
  }
};

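// Usage sketch (illustrative only, not part of this header): the tile sizes,
// buffer names, and index names below are assumptions chosen to satisfy the
// template constraints. A GEMM-style kernel would typically give each operand
// a BlockLoader over a threadgroup tile and step it along the reduction
// dimension:
//
//   using loader_a_t = BlockLoader<
//       float,
//       /* BROWS */ 32,
//       /* BCOLS */ 32,
//       /* dst_ld */ 32,
//       /* reduction_dim */ 1, // advance along columns (K)
//       /* tgp_size */ 128>;   // => n_reads = 8, TCOLS = 4, TROWS = 32
//
//   threadgroup float As[32 * 32];
//   thread loader_a_t loader_a(A, lda, As, simd_gid, simd_lid);
//   for (int k = 0; k < K; k += 32) {
//     loader_a.load_unsafe();
//     threadgroup_barrier(mem_flags::mem_threadgroup);
//     // ... accumulate from As ...
//     threadgroup_barrier(mem_flags::mem_threadgroup);
//     loader_a.next();
//   }
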
template <int R, int C>
struct CShape {
  STEEL_CONST int kRows = R;
  STEEL_CONST int kCols = C;
};

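// BlockLoaderT generalizes BlockLoader's destination addressing: the
// threadgroup tile is indexed with explicit row/column strides (kDstStrRow,
// kDstStrCol) rather than a single leading dimension, letting callers choose
// the staged layout (e.g. a transposed tile). The trade-off is that loads are
// element-by-element instead of the vectorized ReadVector copies above.
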
template <
    typename T,
    short BROWS,
    short BCOLS,
    short kDstStrRow,
    short kDstStrCol,
    short reduction_dim,
    short tgp_size,
    short n_reads = (BCOLS * BROWS) / (tgp_size),
    short TCOLS = BCOLS / n_reads,
    short TROWS = tgp_size / TCOLS>
struct BlockLoaderT {
  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
  STEEL_CONST short vec_size = n_reads;

  // Leading dimension for src
  const int src_ld;
  const int tile_stride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  /* Constructor */
  METAL_FUNC BlockLoaderT(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),
        src(src_ + bi * src_ld + bj) {}

  /* Apply operation to threadgroup without bound checking */
  template <typename UnaryOp>
  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * kDstStrRow + j * kDstStrCol] =
            op.apply(dst[i * kDstStrRow + j * kDstStrCol]);
      }
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j];
      }
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    src_tile_dim = src_tile_dim - short2(bj, bi);

    // Skip loading if thread has no valid reads
    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * kDstStrRow + j * kDstStrCol] = T(0);
        }
      }
      return;
    }

    // Use fast thread memory for bound checks: out-of-bounds lanes are
    // redirected to index 0 and zeroed afterwards, avoiding divergent branches
    bool tmp_idx[vec_size];
    T tmp_val[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      // Make sure tmp_idx only contains valid indices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
      }

      // Read valid indices into tmp_val
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
      }

      // Zero out unneeded values
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
      }

      // Copy values to threadgroup memory
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    src += tile_stride;
  }
};

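// Usage sketch for BlockLoaderT (illustrative only): the sizes, strides, and
// names below are assumptions, not requirements. Swapping the destination
// strides stages a tile transposed in threadgroup memory, since element
// (i, j) of the source tile lands at Bs[j * 32 + i]:
//
//   using loader_bt_t = BlockLoaderT<
//       float,
//       /* BROWS */ 32,
//       /* BCOLS */ 32,
//       /* kDstStrRow */ 1,  // consecutive src rows are adjacent in a column
//       /* kDstStrCol */ 32, // consecutive src cols step by the tile leading dim
//       /* reduction_dim */ 1,
//       /* tgp_size */ 128>;
//
//   threadgroup float Bs[32 * 32];
//   thread loader_bt_t loader_b(B, ldb, Bs, simd_gid, simd_lid);
//   loader_b.load_safe(short2(cols_remaining, rows_remaining));
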
} // namespace steel
} // namespace mlx