MLX
 
Loading...
Searching...
No Matches
steel_attention.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3using namespace mlx::steel;
4
6// Attention kernels
8
// Function constants bound at pipeline-creation time. When false, the last
// block along the Q / K sequence dimension is partial and the kernel uses
// bounds-checked (load_safe / store_safe) paths for it; when true those
// checks are compiled out.
9constant bool align_Q [[function_constant(200)]];
10constant bool align_K [[function_constant(201)]];
11
12template <typename T>
15 METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
16
17 METAL_FUNC T apply(T x) const {
18 return scale * x;
19 }
20};
21
22struct MaxOp {
23 template <typename T>
24 METAL_FUNC static constexpr T apply(T x, T y) {
25 return metal::max(x, y);
26 }
27};
28
29struct SumOp {
30 template <typename T>
31 METAL_FUNC static constexpr T apply(T x, T y) {
32 return x + y;
33 }
34};
35
36struct MulOp {
37 template <typename T>
38 METAL_FUNC static constexpr T apply(T x, T y) {
39 return x * y;
40 }
41};
42
43struct SubOp {
44 template <typename T>
45 METAL_FUNC static constexpr T apply(T x, T y) {
46 return x - y;
47 }
48};
49
50struct ExpSubOp {
51 template <typename T>
52 METAL_FUNC static constexpr T apply(T x, T y) {
53 return fast::exp2(x - y);
54 }
55};
56
57struct DivOp {
58 template <typename T>
59 METAL_FUNC static constexpr T apply(T x, T y) {
60 return x / y;
61 }
62};
63
// Flash-style scaled-dot-product attention kernel. One threadgroup handles a
// BQ-row block of queries for one (batch, head); K and V are streamed in
// BK-row blocks while a running row max / row sum keeps the softmax "online",
// so the full score matrix is never materialized.
//
// Template params: T element type, BQ/BK query/key block sizes, BD head dim,
// WM x WN simdgroups per threadgroup, AccumType accumulation precision.
//
// NOTE(review): this listing is a doxygen extraction. Original lines 165,
// 181-185 (the MMA tile declarations: Stile, Qtile, Ktile, Vtile, Otile and
// MMAFrag_acc_t) and the STEEL_PRAGMA_UNROLL markers before the unrolled
// loops were dropped by the extraction — those names are declared in the
// missing span, not in this text.
64// clang-format off
65template <
 66 typename T,
 67 int BQ,
 68 int BK,
 69 int BD,
 70 int WM,
 71 int WN,
 72 typename AccumType = float>
73[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
 74 const device T* Q [[buffer(0)]],
 75 const device T* K [[buffer(1)]],
 76 const device T* V [[buffer(2)]],
 77 device T* O [[buffer(3)]],
 78 const constant AttnParams* params [[buffer(4)]],
 79 uint simd_lane_id [[thread_index_in_simdgroup]],
 80 uint simd_group_id [[simdgroup_index_in_threadgroup]],
 81 uint3 tid [[threadgroup_position_in_grid]],
 82 uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
 83
 84 // Pacifying compiler
 85 (void)lid;
 86
 87 // Move to correct block: tid = (Q-seq block, head, batch)
 88 ulong3 tidl{tid.x, tid.y, tid.z};
 89
 90 Q += tidl.z * params->Q_strides[0] + // Batch
 91 tidl.y * params->Q_strides[1] + // Head
 92 tidl.x * BQ * params->Q_strides[2]; // Sequence
 93
 // Grouped-query attention: gqa_factor query heads share one KV head.
 94 ulong kv_head_idx = int(tid.y) / params->gqa_factor;
 95 K += tidl.z * params->K_strides[0] + // Batch
 96 kv_head_idx * params->K_strides[1]; // Head
 97
 98 V += tidl.z * params->V_strides[0] + // Batch
 99 kv_head_idx * params->V_strides[1]; // Head
 100
 101 O += tidl.z * params->O_strides[0] + // Batch
 102 tidl.y * params->O_strides[1] + // Head
 103 tidl.x * BQ * params->O_strides[2]; // Sequence
 104
 105 // Prepare threadgroup memory. Rows are padded by 16 bytes worth of
 // elements — presumably to avoid bank conflicts (TODO confirm).
 106 constexpr short padQ = 16 / sizeof(T);
 107 constexpr short padK = 16 / sizeof(T);
 108 constexpr short padV = 16 / sizeof(T);
 109
 // Leading dimensions (row strides) of the padded tiles.
 110 constexpr short LDQ_tgp = BD + padQ;
 111 constexpr short LDK_tgp = BK + padK;
 112 constexpr short LDV_tgp = BD + padV;
 113
 // K and V tiles reuse the same buffer (K is consumed before V is
 // loaded, with a barrier in between), so size it for the larger of the two.
 114 constexpr short tgp_mem_0 = (BK + padK) * (BD);
 115 constexpr short tgp_mem_1 = BK * (BD + padV);
 116 constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1;
 117
 118 threadgroup T Q_smem[BQ * (BD + padQ)];
 119 threadgroup T KV_smem[tgp_mem_s];
 120
 121 threadgroup T* Qs = Q_smem;
 122 threadgroup T* Ks = KV_smem;
 123 threadgroup T* Vs = KV_smem;
 124
 125 // Prepare block loaders
 126 using QBlockLoader = BlockLoaderT<
 127 /* typename T = */ T,
 128 /* short BROWS = */ BQ,
 129 /* short BCOLS = */ BD,
 130 /* short kDstStrRow = */ LDQ_tgp,
 131 /* short kDstStrCol = */ 1,
 132 /* short reduction_dim = */ 1,
 133 /* short tgp_size = */ WM * WN * 32>;
 134
 135 // K is loaded in transposed form (row/col destination strides swapped)
 136 using KBlockLoader = BlockLoaderT<
 137 /* typename T = */ T,
 138 /* short BROWS = */ BK,
 139 /* short BCOLS = */ BD,
 140 /* short kDstStrRow = */ 1,
 141 /* short kDstStrCol = */ LDK_tgp,
 142 /* short reduction_dim = */ 0,
 143 /* short tgp_size = */ WM * WN * 32>;
 144
 145 using VBlockLoader = BlockLoaderT<
 146 /* typename T = */ T,
 147 /* short BROWS = */ BK,
 148 /* short BCOLS = */ BD,
 149 /* short kDstStrRow = */ LDV_tgp,
 150 /* short kDstStrCol = */ 1,
 151 /* short reduction_dim = */ 0,
 152 /* short tgp_size = */ WM * WN * 32>;
 153
 154 QBlockLoader loader_q(
 155 Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
 156 KBlockLoader loader_k(
 157 K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
 158 VBlockLoader loader_v(
 159 V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
 160
 // 1.44269504089 = log2(e): scores are pre-scaled so the softmax
 // exponentials can use fast::exp2 (see ExpSubOp and the factor loop below).
 161 TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
 162
 163 // Prepare MMA tiles
 164 constexpr short kFragSize = 8; // MMAFrag size
 166
 167 constexpr int kNWarps = WM * WN;
 168 static_assert(
 169 BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
 170 "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
 171
 172 // Q seq frags per warp
 173 constexpr int TQ = BQ / (kNWarps * kFragSize);
 174 // KV sequence frags (all warps load the same frags)
 175 constexpr int TK = BK / kFragSize;
 176 // HeadDim frags (all warps load the same frags)
 177 constexpr int TD = BD / kFragSize;
 178
 179 static_assert(TQ == 1, "Check TQ");
 180
 // NOTE(review): original lines 181-185 — the declarations of Stile
 // (scores), Qtile, Ktile, Vtile and Otile (output accumulator) — are
 // missing from this extraction.
 186
 187 Otile.clear();
 188
 189 // Prepare mma tile offsets
 190 const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
 191 const short sm = simd_coord.y;
 192 const short sn = simd_coord.x;
 193 const short tm = kFragSize * TQ * simd_group_id;
 194
 195 const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
 196 const short Ks_offset = sm * LDK_tgp + sn;
 197 const short Vs_offset = sm * LDV_tgp + sn;
 198
 199 constexpr short Qs_tile_stride = kFragSize;
 200 constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
 201
 202 threadgroup_barrier(mem_flags::mem_threadgroup);
 203
 204 // Load Q block and apply scale; bounds-check the last, partial block
 // along the query sequence when align_Q is false.
 205 if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
 206 loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
 207 } else {
 208 loader_q.load_unsafe();
 209 }
 210 loader_q.apply_inplace_op(ts);
 211
 212 // Init row reduction variables
 213 constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;
 214
 215 AccumType max_score[kRowsPT];
 216 AccumType sum_score[kRowsPT] = {0};
 217
 218 // Init to -Inf
 // NOTE(review): Limits<AccumType>::min (utils.h) is assumed to be the
 // lowest representable value, acting as the -inf seed — confirm there.
 220 for (short i = 0; i < kRowsPT; ++i) {
 221 max_score[i] = Limits<AccumType>::min;
 222 }
 223
 224 // Loop over KV seq length
 225 for (int kb = 0; kb < params->NK; kb++) {
 226 // Load K block; bounds-check the last partial block when !align_K
 227 threadgroup_barrier(mem_flags::mem_threadgroup);
 228 if (!align_K && kb == (params->NK_aligned)) {
 229 loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
 230 } else {
 231 loader_k.load_unsafe();
 232 }
 233
 234 // Do S = Q @ K.T
 235 Stile.clear();
 236
 237 threadgroup_barrier(mem_flags::mem_threadgroup);
 238
 240 for (short dd = 0; dd < TD; dd++) {
 241 simdgroup_barrier(mem_flags::mem_none);
 242
 243 Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
 244 &Qs[Qs_offset + dd * Qs_tile_stride]);
 245 Ktile.template load<T, 1, 1, LDK_tgp, 1>(
 246 &Ks[Ks_offset + dd * Ks_tile_stride]);
 247
 248 simdgroup_barrier(mem_flags::mem_none);
 249
 250 tile_matmad(Stile, Qtile, Ktile, Stile);
 251 }
 252
 253 // Mask out-of-length columns of the last K block with -inf so they
 // contribute exp(-inf) = 0 to the softmax.
 254 if (!align_K && kb == (params->NK_aligned)) {
 255 using stile_t = decltype(Stile);
 256 using selem_t = typename stile_t::elem_type;
 257 constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
 258 const short lim = params->kL - params->NK_aligned * BK;
 259
 261 for (short i = 0; i < stile_t::kTileRows; i++) {
 263 for (short j = 0; j < stile_t::kTileCols; j++) {
 264 short col_pos = sn + (j * stile_t::kFragCols);
 266 for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
 267 if ((col_pos + jj) >= lim) {
 268 Stile.frag_at(i, j)[jj] = neg_inf;
 269 }
 270 }
 271 }
 272 }
 273 }
 274
 275 threadgroup_barrier(mem_flags::mem_threadgroup);
 276
 277 // Load V blocks (reuses the K threadgroup buffer — barrier above
 // guarantees all reads of Ks are done)
 278 if (!align_K && kb == (params->NK_aligned)) {
 279 loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
 280 } else {
 281 loader_v.load_unsafe();
 282 }
 283
 284 // Do softmax (online): find the new row max, exponentiate shifted
 // scores in base 2, and rescale previous partial results by
 // exp2(old_max - new_max) so all terms share the new max.
 285
 286 // Temp variables
 287 AccumType new_max[kRowsPT];
 288 AccumType factor[kRowsPT];
 290 for (short i = 0; i < kRowsPT; ++i) {
 291 new_max[i] = max_score[i];
 292 }
 293
 294 // Row max
 295 Stile.template row_reduce<MaxOp>(new_max);
 296
 297 // exp(Si - rowmax(Si))
 298 Stile.template row_bin_op<ExpSubOp>(new_max);
 299
 300 // Factor exp(rowmax(Si) - rowmax(Si-1))
 302 for (short i = 0; i < kRowsPT; ++i) {
 303 factor[i] = fast::exp2(max_score[i] - new_max[i]);
 304 }
 305
 306 // Save max for next iteration
 308 for (short i = 0; i < kRowsPT; ++i) {
 309 max_score[i] = new_max[i];
 310 }
 311
 312 // Row Sum
 313 AccumType sum_score_tmp[kRowsPT] = {0};
 314 Stile.template row_reduce<SumOp>(sum_score_tmp);
 315
 316 // Update norm
 318 for (short i = 0; i < kRowsPT; ++i) {
 319 sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
 320 }
 321
 322 // Update O: rescale the accumulated output to the new max
 323 Otile.template row_bin_op<MulOp>(factor);
 324
 325 // Load V into registers and accumulate O += softmax(S) @ V
 326 threadgroup_barrier(mem_flags::mem_threadgroup);
 327
 329 for (short iq = 0; iq < TQ; iq++) {
 331 for (short id = 0; id < TD; id++) {
 333 for (short ik = 0; ik < TK; ik++) {
 334 if constexpr (BD == 128) {
 335 simdgroup_barrier(mem_flags::mem_none);
 336 }
 337
 338 const short kk = ik * kFragSize;
 339 const short dd = id * kFragSize;
 340
 341 Vtile.template load<T, 1, 1, LDV_tgp, 1>(
 342 &Vs[Vs_offset + kk * LDV_tgp + dd]);
 343
 344 if constexpr (BD == 128) {
 345 simdgroup_barrier(mem_flags::mem_none);
 346 }
 347
 348 MMAFrag_acc_t::mma(
 349 Otile.frag_at(iq, id),
 350 Stile.frag_at(iq, ik),
 351 Vtile.frag_at(0, 0),
 352 Otile.frag_at(iq, id));
 353 }
 354 }
 355 }
 356
 357 // Prepare for next iteration: advance loaders to the next KV block
 358 loader_k.next();
 359 loader_v.next();
 360 }
 361
 362 // Normalize output by the final row sums
 363 Otile.template row_bin_op<DivOp>(sum_score);
 364 threadgroup_barrier(mem_flags::mem_none);
 365
 366 // Store results; each thread writes from its fragment offset
 367 O += (tm + sm) * params->O_strides[2] + sn;
 368
 369 if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
 370 auto dst_tile_dims =
 371 short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
 372
 373 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
 374 return;
 375
 376 Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
 377 } else {
 378 Otile.template store<T, 1, 1>(O, params->O_strides[2]);
 379 }
 380}
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
Definition attn.h:19
METAL_FUNC void tile_matmad(thread MMATile< Dtype, M, N, MMAFragD > &D, thread MMATile< Atype, M, K, MMAFragA > &A, thread MMATile< Btype, K, N, MMAFragB > &B, thread MMATile< Ctype, M, N, MMAFragC > &C)
Definition mma.h:432
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
constant bool align_Q
Definition steel_attention.h:9
void attention(const device T *Q, const device T *K, const device T *V, device T *O, const constant AttnParams *params, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
Definition steel_attention.h:73
constant bool align_K
Definition steel_attention.h:10
Definition steel_attention.h:57
static METAL_FUNC constexpr T apply(T x, T y)
Definition steel_attention.h:59
Definition steel_attention.h:50
static METAL_FUNC constexpr T apply(T x, T y)
Definition steel_attention.h:52
static const constant U min
Definition utils.h:25
Definition steel_attention.h:22
static METAL_FUNC constexpr T apply(T x, T y)
Definition steel_attention.h:24
Definition steel_attention.h:36
static METAL_FUNC constexpr T apply(T x, T y)
Definition steel_attention.h:38
Definition steel_attention.h:43
static METAL_FUNC constexpr T apply(T x, T y)
Definition steel_attention.h:45
Definition steel_attention.h:29
static METAL_FUNC constexpr T apply(T x, T y)
Definition steel_attention.h:31
Definition steel_attention.h:13
METAL_FUNC T apply(T x) const
Definition steel_attention.h:17
T scale
Definition steel_attention.h:14
METAL_FUNC TransformScale(T scale_)
Definition steel_attention.h:15
Definition params.h:12
Definition mma.h:37
Definition loader.h:153
Definition mma.h:231
METAL_FUNC constexpr thread frag_type & frag_at(const short i, const short j)
Definition mma.h:264
METAL_FUNC constexpr void clear()
Definition mma.h:257