[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]]
void attention(
    const device T* Q [[buffer(0)]],
    const device T* K [[buffer(1)]],
    const device T* V [[buffer(2)]],
    device T* O [[buffer(3)]],
    const constant AttnParams* params [[buffer(4)]],
    const constant AttnMaskParams* mask_params
        [[buffer(5), function_constant(has_mask)]], // struct name assumed
    const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
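  // lid is otherwise unused; reference it to silence the compiler warning
  (void)lid;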
  // Move to the block this threadgroup handles
  ulong3 tidl{tid.x, tid.y, tid.z};

  Q += tidl.z * params->Q_strides[0] + // Batch
      tidl.y * params->Q_strides[1] + // Head
      tidl.x * BQ * params->Q_strides[2]; // Sequence

  // With grouped-query attention, gqa_factor query heads share each KV head
  ulong kv_head_idx = int(tid.y) / params->gqa_factor;
  K += tidl.z * params->K_strides[0] + // Batch
      kv_head_idx * params->K_strides[1]; // Head

  V += tidl.z * params->V_strides[0] + // Batch
      kv_head_idx * params->V_strides[1]; // Head

  O += tidl.z * params->O_strides[0] + // Batch
      tidl.y * params->O_strides[1] + // Head
      tidl.x * BQ * params->O_strides[2]; // Sequence

  if (has_mask) {
    mask += tidl.z * mask_params->M_strides[0] + // Batch
        tidl.y * mask_params->M_strides[1]; // Head
  }
  // Prepare threadgroup memory layouts; rows are padded by 16 bytes worth of
  // elements to reduce bank conflicts
  constexpr short padQ = 16 / sizeof(T);
  constexpr short padK = 16 / sizeof(T);
  constexpr short padV = 16 / sizeof(T);

  constexpr short LDQ_tgp = BD + padQ;
  constexpr short LDK_tgp = BK + padK;
  constexpr short LDV_tgp = BD + padV;
  // K and V tiles alias one allocation, sized for the larger of the two:
  // K is staged transposed as (BD x LDK_tgp), V row-major as (BK x LDV_tgp)
  constexpr short tgp_mem_0 = (BK + padK) * (BD);
  constexpr short tgp_mem_1 = BK * (BD + padV);
  constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1;

  threadgroup T Q_smem[BQ * (BD + padQ)];
  threadgroup T KV_smem[tgp_mem_s];

  threadgroup T* Qs = Q_smem;
  threadgroup T* Ks = KV_smem;
  threadgroup T* Vs = KV_smem;
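  // Block loader aliases: a sketch, assuming the steel BlockLoaderT template;
  // the destination strides mirror the threadgroup layouts above (Q and V
  // row-major, K transposed), and the exact parameter list may differ in the
  // original source
  using QBlockLoader = BlockLoaderT<
      /* typename T = */ T,
      /* short BROWS = */ BQ,
      /* short BCOLS = */ BD,
      /* short kDstStrRow = */ LDQ_tgp,
      /* short kDstStrCol = */ 1,
      /* short reduction_dim = */ 1,
      /* short tgp_size = */ WM * WN * 32>;
  using KBlockLoader = BlockLoaderT<
      T, BK, BD, /* kDstStrRow = */ 1, /* kDstStrCol = */ LDK_tgp, 0, WM * WN * 32>;
  using VBlockLoader = BlockLoaderT<
      T, BK, BD, /* kDstStrRow = */ LDV_tgp, /* kDstStrCol = */ 1, 0, WM * WN * 32>;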
  QBlockLoader loader_q(
      Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
  KBlockLoader loader_k(
      K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
  VBlockLoader loader_v(
      V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
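  // Scale transform applied to Q after loading. The multiplier folds the
  // attention scale together with log2(e) so the softmax below can use exp2;
  // TransformScale and params->scale are assumed from how ts is used with
  // loader_q.apply_inplace_op below
  TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));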
  // MMA fragments are 8x8 simdgroup matrices
  constexpr short kFragSize = 8;
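  // Accumulator fragment type; assumed to be the steel BaseMMAFrag
  // specialization behind MMAFrag_acc_t::get_coord and ::mma used below
  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;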
  constexpr int kNWarps = WM * WN;
  static_assert(
      BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
      "Each simdgroup must host at least 1 simdgroup matrix along Q sequence.");

  // Q sequence fragments per warp
  constexpr int TQ = BQ / (kNWarps * kFragSize);
  // KV sequence fragments (all warps load the same fragments)
  constexpr int TK = BK / kFragSize;
  // Head-dim fragments (all warps load the same fragments)
  constexpr int TD = BD / kFragSize;

  static_assert(TQ == 1, "Check TQ");

  // Register tiles; shapes assumed from how each tile is used below:
  // Stile holds the Q @ K.T scores, Otile accumulates the output
  MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
  MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
  MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
  MMATile<AccumType, 1, 1, MMAFrag_acc_t> Vtile;
  MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;

  Otile.clear();
  // Map each lane to its coordinates within the simdgroup matrix fragment
  const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
  const short sm = simd_coord.y;
  const short sn = simd_coord.x;
  const short tm = kFragSize * TQ * simd_group_id;

  const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
  const short Ks_offset = sm * LDK_tgp + sn;
  const short Vs_offset = sm * LDV_tgp + sn;

  constexpr short Qs_tile_stride = kFragSize;
  constexpr short Ks_tile_stride = kFragSize * LDK_tgp;

  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Load the Q block and apply the scale; guard the trailing block when the
  // Q sequence length is not aligned to BQ
  if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
    loader_q.load_safe(short2(BD, params->qL_rem));
  } else {
    loader_q.load_unsafe();
  }
  loader_q.apply_inplace_op(ts);
  // Running softmax statistics, one entry per row this thread owns
  constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;

  AccumType max_score[kRowsPT];
  AccumType sum_score[kRowsPT] = {0};

  // Initialize the running max to the lowest value (Limits is assumed to be
  // the steel numeric-limits helper)
  for (short i = 0; i < kRowsPT; ++i) {
    max_score[i] = Limits<AccumType>::min;
  }
  int kb_lim = params->NK;

  // For causal attention, K blocks entirely past the diagonal can be skipped;
  // do_causal is assumed to be a function constant like align_Q and align_K
  if (do_causal) {
    int q_max = (tid.x + 1) * BQ + params->qL_off;
    kb_lim = (q_max + BK - 1) / BK;
    kb_lim = min(params->NK, kb_lim);
  }

  // Loop over blocks of the KV sequence
  for (int kb = 0; kb < kb_lim; kb++) {
    // Load the K block; guard the trailing block when the K sequence length
    // is not aligned to BK
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (!align_K && kb == (params->NK_aligned)) {
      loader_k.load_safe(short2(BD, params->kL_rem));
    } else {
      loader_k.load_unsafe();
    }

    Stile.clear();

    threadgroup_barrier(mem_flags::mem_threadgroup);
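    // Compute S = Q @ K.T for this block, one kFragSize slice of the head
    // dimension at a time, with operand fragments staged through threadgroup
    // memory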
    for (short dd = 0; dd < TD; dd++) {
      simdgroup_barrier(mem_flags::mem_none);

      Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
          &Qs[Qs_offset + dd * Qs_tile_stride]);
      Ktile.template load<T, 1, 1, LDK_tgp, 1>(
          &Ks[Ks_offset + dd * Ks_tile_stride]);

      simdgroup_barrier(mem_flags::mem_none);
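      // Accumulate this head-dim slice into the score tile; tile_matmad is
      // assumed to be the steel tile matmul-accumulate helper
      tile_matmad(Stile, Qtile, Ktile, Stile);
    }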
    // Mask out the padded tail of the K sequence in the final K block
    if (!align_K && kb == (params->NK_aligned)) {
      using stile_t = decltype(Stile);
      using selem_t = typename stile_t::elem_type;
      constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();

      for (short i = 0; i < stile_t::kTileRows; i++) {
        for (short j = 0; j < stile_t::kTileCols; j++) {
          short col_pos = sn + (j * stile_t::kFragCols);
          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
            if ((col_pos + jj) >= params->kL_rem) {
              Stile.frag_at(i, j)[jj] = neg_inf;
            }
          }
        }
      }
    }
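    // Causal masking: score elements above the diagonal (row_pos < col_pos)
    // are set to -inf before the softmax. The do_causal flag and the test
    // below for which K blocks can cross the diagonal are assumed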
    if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
      using stile_t = decltype(Stile);
      using selem_t = typename stile_t::elem_type;
      constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();

      for (short i = 0; i < stile_t::kTileRows; i++) {
        const int row_pos =
            tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows);
        for (short j = 0; j < stile_t::kTileCols; j++) {
          const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
            if (row_pos < (col_pos + jj)) {
              Stile.frag_at(i, j)[jj] = neg_inf;
            }
          }
        }
      }
    }
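    // Apply the user-provided mask: a bool mask selects which score elements
    // survive, while an additive mask is added to the scores (scaled by
    // log2(e) to match the base-2 softmax)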
    if (has_mask) {
      using stile_t = decltype(Stile);
      using selem_t = typename stile_t::elem_type;
      constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();

      constexpr bool is_bool = is_same_v<MaskType, bool>;
      using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;

      using MMAFrag_mask_t = BaseMMAFrag<melem_t, kFragSize, kFragSize>;
      using frag_t = typename MMAFrag_mask_t::frag_type;

      for (short i = 0; i < stile_t::kTileRows; i++) {
        const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows);
        for (short j = 0; j < stile_t::kTileCols; j++) {
          const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);

          frag_t mfrag;

          // Bounds-checked fragment load from the mask at (row_pos, col_pos);
          // the argument order follows the steel load_safe helpers and is
          // partly assumed
          MMAFrag_mask_t::load_safe(
              mfrag,
              mask,
              int(mask_params->M_strides[2]),
              Int<1>{},
              params->qL,
              params->kL,
              row_pos,
              col_pos);

          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) {
            if constexpr (is_bool) {
              Stile.frag_at(i, j)[jj] =
                  mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
            } else {
              Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
            }
          }
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load the V block, guarding the trailing block as with K
    if (!align_K && kb == (params->NK_aligned)) {
      loader_v.load_safe(short2(BD, params->kL_rem));
    } else {
      loader_v.load_unsafe();
    }
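    // Online softmax update: keep a running row max (max_score) and row sum
    // (sum_score); when the max grows, rescale the previous sum and the
    // accumulated output by exp2(old_max - new_max) so every term stays
    // relative to the current max. All exponentials are base 2 since the
    // scores were pre-multiplied by log2(e)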
    AccumType new_max[kRowsPT];
    AccumType factor[kRowsPT];
    for (short i = 0; i < kRowsPT; ++i) {
      new_max[i] = max_score[i];
    }

    // Row max
    Stile.template row_reduce<MaxOp>(new_max);

    // S = exp2(S - rowmax(S))
    Stile.template row_bin_op<ExpSubOp>(new_max);

    // Rescale factor exp2(rowmax_old - rowmax_new)
    for (short i = 0; i < kRowsPT; ++i) {
      factor[i] = fast::exp2(max_score[i] - new_max[i]);
    }

    // Save the max for the next iteration
    for (short i = 0; i < kRowsPT; ++i) {
      max_score[i] = new_max[i];
    }

    // Row sum for the current block
    AccumType sum_score_tmp[kRowsPT] = {0};
    Stile.template row_reduce<SumOp>(sum_score_tmp);

    // Update the running normalizer
    for (short i = 0; i < kRowsPT; ++i) {
      sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
    }

    // Rescale the accumulated output
    Otile.template row_bin_op<MulOp>(factor);
    // Accumulate O += S @ V, loading V fragments from threadgroup memory
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (short iq = 0; iq < TQ; iq++) {
      for (short id = 0; id < TD; id++) {
        for (short ik = 0; ik < TK; ik++) {
          if constexpr (BD == 128) {
            simdgroup_barrier(mem_flags::mem_none);
          }

          const short kk = ik * kFragSize;
          const short dd = id * kFragSize;

          Vtile.template load<T, 1, 1, LDV_tgp, 1>(
              &Vs[Vs_offset + kk * LDV_tgp + dd]);

          if constexpr (BD == 128) {
            simdgroup_barrier(mem_flags::mem_none);
          }
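          // Multiply-accumulate the score fragment with the V fragment into
          // the output accumulator; MMAFrag mma(D, A, B, C) is assumed to
          // compute D = A * B + C with the accumulator as both C and D
          MMAFrag_acc_t::mma(
              Otile.frag_at(iq, id),
              Stile.frag_at(iq, ik),
              Vtile.frag_at(0, 0),
              Otile.frag_at(iq, id));
        }
      }
    }
  }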
  // Normalize by the softmax row sums
  Otile.template row_bin_op<DivOp>(sum_score);
  threadgroup_barrier(mem_flags::mem_none);

  // Move O to this thread's store position
  O += (tm + sm) * params->O_strides[2] + sn;
  // Store the result; the trailing Q block needs a bounds-checked store when
  // the Q sequence length is not aligned to BQ
  if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
    auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
  } else {
    Otile.template store<T, 1, 1>(O, params->O_strides[2]);
  }
}