MLX
 
Loading...
Searching...
No Matches
steel_attention.h
Go to the documentation of this file.
1// Copyright © 2024-25 Apple Inc.
2
3using namespace mlx::steel;
4
6// Attention kernels
8
// Compile-time function constants, bound at pipeline-creation time.
// align_Q / align_K: true when the query / key sequence length is an exact
// multiple of the block size, so the edge blocks need no bounds checking
// (selects the load_unsafe vs load_safe paths in the kernel below).
9constant bool align_Q [[function_constant(200)]];
10constant bool align_K [[function_constant(201)]];
11
// has_mask: an explicit mask buffer (boolean or additive) is provided.
// do_causal: apply causal (lower-triangular) masking to the scores.
12constant bool has_mask [[function_constant(300)]];
13constant bool do_causal [[function_constant(301)]];
14
15template <typename T>
18 METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
19
20 METAL_FUNC T apply(T x) const {
21 return scale * x;
22 }
23};
24
25struct MaxOp {
26 template <typename T>
27 METAL_FUNC static constexpr T apply(T x, T y) {
28 return metal::max(x, y);
29 }
30};
31
32struct SumOp {
33 template <typename T>
34 METAL_FUNC static constexpr T apply(T x, T y) {
35 return x + y;
36 }
37};
38
39struct MulOp {
40 template <typename T>
41 METAL_FUNC static constexpr T apply(T x, T y) {
42 return x * y;
43 }
44};
45
46struct SubOp {
47 template <typename T>
48 METAL_FUNC static constexpr T apply(T x, T y) {
49 return x - y;
50 }
51};
52
53struct ExpSubOp {
54 template <typename T>
55 METAL_FUNC static constexpr T apply(T x, T y) {
56 return fast::exp2(x - y);
57 }
58};
59
60struct DivOp {
61 template <typename T>
62 METAL_FUNC static constexpr T apply(T x, T y) {
63 return x / y;
64 }
65};
66
// Tiled flash-attention style kernel: O = softmax(scale * Q @ K^T [+ mask]) @ V.
// One threadgroup processes a BQ-row block of queries for one (batch, head)
// pair and streams the KV sequence in BK-row blocks, keeping a running row
// max / row sum (online softmax) so the full score matrix is never
// materialized.
//
// Template parameters:
//   T         - element type of Q/K/V/O
//   BQ/BK/BD  - block sizes along query seq, KV seq, and head dimension
//   WM/WN     - simdgroup layout within the threadgroup (WM * WN warps)
//   MaskType  - element type of the optional mask buffer
//   AccumType - accumulation type for scores and output
//
// NOTE(review): this listing is a documentation extraction; original source
// lines 176 and 191-196 (the MMAFrag_acc_t / Qtile / Ktile / Vtile / Stile /
// Otile declarations) and the STEEL_PRAGMA_UNROLL lines were dropped, so the
// register tiles used below have no visible declarations on this page.
67// clang-format off
68template <
 69 typename T,
 70 int BQ,
 71 int BK,
 72 int BD,
 73 int WM,
 74 int WN,
 75 typename MaskType = float,
 76 typename AccumType = float>
77[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
 78 const device T* Q [[buffer(0)]],
 79 const device T* K [[buffer(1)]],
 80 const device T* V [[buffer(2)]],
 81 device T* O [[buffer(3)]],
 82 const constant AttnParams* params [[buffer(4)]],
 83 const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
 84 const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
 85 uint simd_lane_id [[thread_index_in_simdgroup]],
 86 uint simd_group_id [[simdgroup_index_in_threadgroup]],
 87 uint3 tid [[threadgroup_position_in_grid]],
 88 uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
 89
 90 // Pacifying compiler
 91 (void)lid;
 92
// Offset the global pointers to this threadgroup's (batch, head, q-block).
 93 // Move to correct block
 94 ulong3 tidl{tid.x, tid.y, tid.z};
 95
 96 Q += tidl.z * params->Q_strides[0] + // Batch
 97 tidl.y * params->Q_strides[1] + // Head
 98 tidl.x * BQ * params->Q_strides[2]; // Sequence
 99
// GQA: gqa_factor query heads share each KV head.
 100 ulong kv_head_idx = int(tid.y) / params->gqa_factor;
 101 K += tidl.z * params->K_strides[0] + // Batch
 102 kv_head_idx * params->K_strides[1]; // Head
 103
 104 V += tidl.z * params->V_strides[0] + // Batch
 105 kv_head_idx * params->V_strides[1]; // Head
 106
 107 O += tidl.z * params->O_strides[0] + // Batch
 108 tidl.y * params->O_strides[1] + // Head
 109 tidl.x * BQ * params->O_strides[2]; // Sequence
 110
 111 if (has_mask) {
 112 mask += tidl.z * mask_params->M_strides[0] + // Batch
 113 tidl.y * mask_params->M_strides[1]; // Head
 114 }
 115
 116 // Prepare threadgroup memory
// Each tile row is padded by 16 bytes (presumably to avoid shared-memory
// bank conflicts — not verifiable from this file).
 117 constexpr short padQ = 16 / sizeof(T);
 118 constexpr short padK = 16 / sizeof(T);
 119 constexpr short padV = 16 / sizeof(T);
 120
 121 constexpr short LDQ_tgp = BD + padQ;
 122 constexpr short LDK_tgp = BK + padK;
 123 constexpr short LDV_tgp = BD + padV;
 124
 125 constexpr short tgp_mem_0 = (BK + padK) * (BD);
 126 constexpr short tgp_mem_1 = BK * (BD + padV);
 127 constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1;
 128
// K and V share one threadgroup buffer: their uses are separated by the
// barriers inside the KV loop, so the storage can be reused.
 129 threadgroup T Q_smem[BQ * (BD + padQ)];
 130 threadgroup T KV_smem[tgp_mem_s];
 131
 132 threadgroup T* Qs = Q_smem;
 133 threadgroup T* Ks = KV_smem;
 134 threadgroup T* Vs = KV_smem;
 135
 136 // Prepare block loaders
 137 using QBlockLoader = BlockLoaderT<
 138 /* typename T = */ T,
 139 /* short BROWS = */ BQ,
 140 /* short BCOLS = */ BD,
 141 /* short kDstStrRow = */ LDQ_tgp,
 142 /* short kDstStrCol = */ 1,
 143 /* short reduction_dim = */ 1,
 144 /* short tgp_size = */ WM * WN * 32>;
 145
 146 // K is loaded transposed (destination row/col strides are swapped)
 147 using KBlockLoader = BlockLoaderT<
 148 /* typename T = */ T,
 149 /* short BROWS = */ BK,
 150 /* short BCOLS = */ BD,
 151 /* short kDstStrRow = */ 1,
 152 /* short kDstStrCol = */ LDK_tgp,
 153 /* short reduction_dim = */ 0,
 154 /* short tgp_size = */ WM * WN * 32>;
 155
 156 using VBlockLoader = BlockLoaderT<
 157 /* typename T = */ T,
 158 /* short BROWS = */ BK,
 159 /* short BCOLS = */ BD,
 160 /* short kDstStrRow = */ LDV_tgp,
 161 /* short kDstStrCol = */ 1,
 162 /* short reduction_dim = */ 0,
 163 /* short tgp_size = */ WM * WN * 32>;
 164
 165 QBlockLoader loader_q(
 166 Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
 167 KBlockLoader loader_k(
 168 K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
 169 VBlockLoader loader_v(
 170 V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
 171
// log2(e) = 1.44269504089 is folded into the scale so the exp2-based
// softmax (ExpSubOp / fast::exp2) computes a natural exponential.
 172 TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
 173
 174 // Prepare MMA tiles
 175 constexpr short kFragSize = 8; // MMAFrag size
 177
 178 constexpr int kNWarps = WM * WN;
 179 static_assert(
 180 BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
 181 "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
 182
 183 // Q seq frags per warp
 184 constexpr int TQ = BQ / (kNWarps * kFragSize);
 185 // KV sequence frags (all warps load the same frags)
 186 constexpr int TK = BK / kFragSize;
 187 // HeadDim frags (all warps load the same frags)
 188 constexpr int TD = BD / kFragSize;
 189
 190 static_assert(TQ == 1, "Check TQ");
 191
 197
// NOTE(review): Otile (and Qtile/Ktile/Vtile/Stile/MMAFrag_acc_t) are
// declared on original lines 191-196, dropped by the doc extraction.
 198 Otile.clear();
 199
 200 // Prepare mma tile offsets
 201 const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
 202 const short sm = simd_coord.y;
 203 const short sn = simd_coord.x;
 204 const short tm = kFragSize * TQ * simd_group_id;
 205
 206 const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
 207 const short Ks_offset = sm * LDK_tgp + sn;
 208 const short Vs_offset = sm * LDV_tgp + sn;
 209
 210 constexpr short Qs_tile_stride = kFragSize;
 211 constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
 212
 213 threadgroup_barrier(mem_flags::mem_threadgroup);
 214
 215 // Load Q block and apply scale; the last (ragged) q-block uses the
 215 // bounds-checked path when qL is not a multiple of BQ.
 216 if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
 217 loader_q.load_safe(short2(BD, params->qL_rem));
 218 } else {
 219 loader_q.load_unsafe();
 220 }
 221 loader_q.apply_inplace_op(ts);
 222
 223 // Init row reduction variables
 224 constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;
 225
 226 AccumType max_score[kRowsPT];
 227 AccumType sum_score[kRowsPT] = {0};
 228
 229 // Init to -Inf
 231 for (short i = 0; i < kRowsPT; ++i) {
 232 max_score[i] = Limits<AccumType>::min;
 233 }
 234
// With causal masking, KV blocks strictly past the current query block
// contribute nothing, so the loop bound is clamped.
 235 int kb_lim = params->NK;
 236
 237 if (do_causal) {
 238 int q_max = (tid.x + 1) * BQ + params->qL_off;
 239 kb_lim = (q_max + BK - 1) / BK;
 240 }
 241
 242 // Loop over KV seq length
 243 for (int kb = 0; kb < kb_lim; kb++) {
 244 // Load K block and apply scale
 245 threadgroup_barrier(mem_flags::mem_threadgroup);
 246 if (!align_K && kb == (params->NK_aligned)) {
 247 loader_k.load_safe(short2(BD, params->kL_rem));
 248 } else {
 249 loader_k.load_unsafe();
 250 }
 251
 252 // Do S = Q @ K.T
 253 Stile.clear();
 254
 255 threadgroup_barrier(mem_flags::mem_threadgroup);
 256
 258 for (short dd = 0; dd < TD; dd++) {
 259 simdgroup_barrier(mem_flags::mem_none);
 260
 261 Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
 262 &Qs[Qs_offset + dd * Qs_tile_stride]);
 263 Ktile.template load<T, 1, 1, LDK_tgp, 1>(
 264 &Ks[Ks_offset + dd * Ks_tile_stride]);
 265
 266 simdgroup_barrier(mem_flags::mem_none);
 267
 268 tile_matmad(Stile, Qtile, Ktile, Stile);
 269 }
 270
 271 // Mask out the ragged tail of the last KV block (columns >= kL_rem)
 272 if (!align_K && kb == (params->NK_aligned)) {
 273 using stile_t = decltype(Stile);
 274 using selem_t = typename stile_t::elem_type;
 275 constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
 276
 278 for (short i = 0; i < stile_t::kTileRows; i++) {
 280 for (short j = 0; j < stile_t::kTileCols; j++) {
 281 short col_pos = sn + (j * stile_t::kFragCols);
 283 for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
 284 if ((col_pos + jj) >= params->kL_rem) {
 285 Stile.frag_at(i, j)[jj] = neg_inf;
 286 }
 287 }
 288 }
 289 }
 290 }
 291
 292 // Mask out if causal (only the blocks straddling the diagonal need it)
 293 if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) {
 294 using stile_t = decltype(Stile);
 295 using selem_t = typename stile_t::elem_type;
 296 constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
 297
 299 for (short i = 0; i < stile_t::kTileRows; i++) {
 300 const int row_pos =
 301 tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows);
 303 for (short j = 0; j < stile_t::kTileCols; j++) {
 304 const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
 306 for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
 307 if (row_pos < (col_pos + jj)) {
 308 Stile.frag_at(i, j)[jj] = neg_inf;
 309 }
 310 }
 311 }
 312 }
 313 }
 314
 315 // Other masking as needed
 316 if (has_mask) {
 317 using stile_t = decltype(Stile);
 318 using selem_t = typename stile_t::elem_type;
 319 constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
 320
 321 constexpr bool is_bool = is_same_v<MaskType, bool>;
 322 using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;
 323
 324 using MMAFrag_mask_t = BaseMMAFrag<melem_t, kFragSize, kFragSize>;
 325 using frag_t = typename MMAFrag_mask_t::frag_type;
 326
 328 for (short i = 0; i < stile_t::kTileRows; i++) {
 329 const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows);
 331 for (short j = 0; j < stile_t::kTileCols; j++) {
 332 const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
 333
 334 frag_t mfrag;
 335
 336 MMAFrag_mask_t::load_safe(
 337 mfrag,
 338 mask,
 339 int(mask_params->M_strides[2]),
 340 Int<1>{},
 341 params->qL,
 342 params->kL,
 343 row_pos,
 344 col_pos);
 345
// A boolean mask selects/kills score entries; an additive mask is scaled
// by log2(e) to match the base-2 (exp2) softmax above.
 347 for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) {
 348 if constexpr (is_bool) {
 349 Stile.frag_at(i, j)[jj] =
 350 mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
 351 } else {
 352 Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
 353 }
 354 }
 355 }
 356 }
 357 }
 358
 359 threadgroup_barrier(mem_flags::mem_threadgroup);
 360
 361 // Load V blocks
 362 if (!align_K && kb == (params->NK_aligned)) {
 363 loader_v.load_safe(short2(BD, params->kL_rem));
 364 } else {
 365 loader_v.load_unsafe();
 366 }
 367
 368 // Do online softmax update for this KV block
 369
 370 // Temp variables
 371 AccumType new_max[kRowsPT];
 372 AccumType factor[kRowsPT];
 374 for (short i = 0; i < kRowsPT; ++i) {
 375 new_max[i] = max_score[i];
 376 }
 377
 378 // Row max
 379 Stile.template row_reduce<MaxOp>(new_max);
 380
 381 // exp(Si - rowmax(Si))
 382 Stile.template row_bin_op<ExpSubOp>(new_max);
 383
 384 // Factor exp(rowmax(Si) - rowmax(Si-1)) rescales previous partial sums
 386 for (short i = 0; i < kRowsPT; ++i) {
 387 factor[i] = fast::exp2(max_score[i] - new_max[i]);
 388 }
 389
 390 // Save max for next iteration
 392 for (short i = 0; i < kRowsPT; ++i) {
 393 max_score[i] = new_max[i];
 394 }
 395
 396 // Row Sum
 397 AccumType sum_score_tmp[kRowsPT] = {0};
 398 Stile.template row_reduce<SumOp>(sum_score_tmp);
 399
 400 // Update norm
 402 for (short i = 0; i < kRowsPT; ++i) {
 403 sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
 404 }
 405
 406 // Update O (rescale previous accumulator by the correction factor)
 407 Otile.template row_bin_op<MulOp>(factor);
 408
 409 // Load V into registers and accumulate O += softmax(S) @ V
 410 threadgroup_barrier(mem_flags::mem_threadgroup);
 411
 413 for (short iq = 0; iq < TQ; iq++) {
 415 for (short id = 0; id < TD; id++) {
 417 for (short ik = 0; ik < TK; ik++) {
 418 if constexpr (BD == 128) {
 419 simdgroup_barrier(mem_flags::mem_none);
 420 }
 421
 422 const short kk = ik * kFragSize;
 423 const short dd = id * kFragSize;
 424
 425 Vtile.template load<T, 1, 1, LDV_tgp, 1>(
 426 &Vs[Vs_offset + kk * LDV_tgp + dd]);
 427
 428 if constexpr (BD == 128) {
 429 simdgroup_barrier(mem_flags::mem_none);
 430 }
 431
 432 MMAFrag_acc_t::mma(
 433 Otile.frag_at(iq, id),
 434 Stile.frag_at(iq, ik),
 435 Vtile.frag_at(0, 0),
 436 Otile.frag_at(iq, id));
 437 }
 438 }
 439 }
 440
 441 // Prepare for next iteration
 442 loader_k.next();
 443 loader_v.next();
 444 }
 445
 446 // Normalize output by the final softmax row sums
 447 Otile.template row_bin_op<DivOp>(sum_score);
 448 threadgroup_barrier(mem_flags::mem_none);
 449
 450 // Store results
 451 O += (tm + sm) * params->O_strides[2] + sn;
 452
// Ragged last q-block: clamp the store to the valid region (a thread whose
// rows/cols fall entirely outside simply returns).
 453 if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
 454 auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));
 455
 456 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
 457 return;
 458
 459 Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
 460 } else {
 461 Otile.template store<T, 1, 1>(O, params->O_strides[2]);
 462 }
 463}
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
Definition attn.h:19
METAL_FUNC void tile_matmad(thread MMATile< Dtype, M, N, MMAFragD > &D, thread MMATile< Atype, M, K, MMAFragA > &A, thread MMATile< Btype, K, N, MMAFragB > &B, thread MMATile< Ctype, M, N, MMAFragC > &C)
Definition mma.h:432
integral_constant< int, val > Int
Definition integral_constant.h:48
constant bool has_mask
Definition sdpa_vector.h:7
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
constant bool align_Q
Definition steel_attention.h:9
void attention(const device T *Q, const device T *K, const device T *V, device T *O, const constant AttnParams *params, const constant AttnMaskParams *mask_params, const device MaskType *mask, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
Definition steel_attention.h:77
constant bool align_K
Definition steel_attention.h:10
constant bool do_causal
Definition steel_attention.h:13
Definition steel_attention.h:60
static METAL_FUNC constexpr T apply(T x, T y)
Definition steel_attention.h:62
Definition steel_attention.h:53
static METAL_FUNC constexpr T apply(T x, T y)
Definition steel_attention.h:55
static const constant U min
Definition utils.h:25
Definition steel_attention.h:25
static METAL_FUNC constexpr T apply(T x, T y)
Definition steel_attention.h:27
Definition steel_attention.h:39
static METAL_FUNC constexpr T apply(T x, T y)
Definition steel_attention.h:41
Definition steel_attention.h:46
static METAL_FUNC constexpr T apply(T x, T y)
Definition steel_attention.h:48
Definition steel_attention.h:32
static METAL_FUNC constexpr T apply(T x, T y)
Definition steel_attention.h:34
Definition steel_attention.h:16
METAL_FUNC T apply(T x) const
Definition steel_attention.h:20
T scale
Definition steel_attention.h:17
METAL_FUNC TransformScale(T scale_)
Definition steel_attention.h:18
Definition params.h:39
Definition params.h:12
Definition mma.h:37
Definition loader.h:153
Definition mma.h:231
METAL_FUNC constexpr thread frag_type & frag_at(const short i, const short j)
Definition mma.h:264
METAL_FUNC constexpr void clear()
Definition mma.h:257