// NOTE: T, BQ, BK, BD, WM, WN, and AccumType are template parameters of this
// kernel, and align_Q / align_K are Metal function constants; all are
// declared earlier in the file, above this excerpt.
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
    const device T* Q [[buffer(0)]],
    const device T* K [[buffer(1)]],
    const device T* V [[buffer(2)]],
    device T* O [[buffer(3)]],
    const constant AttnParams* params [[buffer(4)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
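  // AttnParams fields used in this excerpt: scale, gqa_factor, qL / kL
  // (query / key sequence lengths), NQ_aligned / NK / NK_aligned (block
  // counts along the Q and KV sequences), and the Q/K/V/O strides laid out
  // as [batch, head, sequence].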
  // Pacify the compiler; lid is otherwise unused
  (void)lid;

  // Move to the correct block
  ulong3 tidl{tid.x, tid.y, tid.z};

  Q += tidl.z * params->Q_strides[0] + // Batch
      tidl.y * params->Q_strides[1] + // Head
      tidl.x * BQ * params->Q_strides[2]; // Sequence

  // Map each query head to its KV head (grouped-query attention)
  ulong kv_head_idx = int(tid.y) / params->gqa_factor;
  K += tidl.z * params->K_strides[0] + // Batch
      kv_head_idx * params->K_strides[1]; // Head

  V += tidl.z * params->V_strides[0] + // Batch
      kv_head_idx * params->V_strides[1]; // Head

  O += tidl.z * params->O_strides[0] + // Batch
      tidl.y * params->O_strides[1] + // Head
      tidl.x * BQ * params->O_strides[2]; // Sequence
  // Prepare threadgroup memory: pad each tile row by 16 bytes worth of
  // elements to avoid shared-memory bank conflicts
  constexpr short padQ = 16 / sizeof(T);
  constexpr short padK = 16 / sizeof(T);
  constexpr short padV = 16 / sizeof(T);

  constexpr short LDQ_tgp = BD + padQ;
  constexpr short LDK_tgp = BK + padK;
  constexpr short LDV_tgp = BD + padV;
  // The K and V tiles are consumed in separate phases of each iteration,
  // so they share one threadgroup buffer sized for the larger layout
  constexpr short tgp_mem_0 = (BK + padK) * (BD);
  constexpr short tgp_mem_1 = BK * (BD + padV);
  constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1;

  threadgroup T Q_smem[BQ * (BD + padQ)];
  threadgroup T KV_smem[tgp_mem_s];

  threadgroup T* Qs = Q_smem;
  threadgroup T* Ks = KV_smem;
  threadgroup T* Vs = KV_smem;
  // Block loaders (QBlockLoader, KBlockLoader, and VBlockLoader are
  // BlockLoaderT specializations defined above this excerpt; K is staged
  // transposed so the S = Q @ K.T loads below are contiguous)
  QBlockLoader loader_q(
      Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
  KBlockLoader loader_k(
      K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
  VBlockLoader loader_v(
      V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
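  // Reconstructed from the elided lines above: the softmax scale is folded
  // into Q together with log2(e), so the exponentials below can use the
  // cheaper exp2
  TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));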
  // Prepare MMA tiles
  constexpr short kFragSize = 8; // MMAFrag size
  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;

  constexpr int kNWarps = WM * WN;
  static_assert(
      BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
      "Each simdgroup must host at least 1 simdgroup matrix along Q sequence.");

  // Q sequence fragments per simdgroup
  constexpr int TQ = BQ / (kNWarps * kFragSize);
  // KV sequence fragments (all simdgroups load the same fragments)
  constexpr int TK = BK / kFragSize;
  // HeadDim fragments (all simdgroups load the same fragments)
  constexpr int TD = BD / kFragSize;

  static_assert(TQ == 1, "Check TQ");
  // Prepare MMA tile offsets
  const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
  const short sm = simd_coord.y;
  const short sn = simd_coord.x;
  const short tm = kFragSize * TQ * simd_group_id;

  const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
  const short Ks_offset = sm * LDK_tgp + sn;
  const short Vs_offset = sm * LDV_tgp + sn;

  constexpr short Qs_tile_stride = kFragSize;
  constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Load the Q block, bounds-checked if this is the last (ragged) Q tile
  if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
    loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
  } else {
    loader_q.load_unsafe();
  }

  // Apply the (log2-scaled) softmax scale to Q in place
  loader_q.apply_inplace_op(ts);
  // Per-thread running softmax state
  constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;

  AccumType max_score[kRowsPT];
  AccumType sum_score[kRowsPT] = {0};

  // Initialize the running row max to the lowest representable value
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < kRowsPT; ++i) {
    max_score[i] = Limits<AccumType>::min;
  }
  // Loop over the KV sequence length in blocks of BK
  for (int kb = 0; kb < params->NK; kb++) {
    // Load the K block, bounds-checked on the last (ragged) K block
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (!align_K && kb == (params->NK_aligned)) {
      loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
    } else {
      loader_k.load_unsafe();
    }

    // S = Q @ K.T
    Stile.clear();

    threadgroup_barrier(mem_flags::mem_threadgroup);
    STEEL_PRAGMA_UNROLL
    for (short dd = 0; dd < TD; dd++) {
      simdgroup_barrier(mem_flags::mem_none);

      Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
          &Qs[Qs_offset + dd * Qs_tile_stride]);
      Ktile.template load<T, 1, 1, LDK_tgp, 1>(
          &Ks[Ks_offset + dd * Ks_tile_stride]);

      simdgroup_barrier(mem_flags::mem_none);

      // Accumulate this head-dim slice into S
      tile_matmad(Stile, Qtile, Ktile, Stile);
    }
    // Mask out-of-bounds score columns on the last (ragged) K block
    if (!align_K && kb == (params->NK_aligned)) {
      using stile_t = decltype(Stile);
      using selem_t = typename stile_t::elem_type;
      constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
      const short lim = params->kL - params->NK_aligned * BK;

      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < stile_t::kTileRows; i++) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < stile_t::kTileCols; j++) {
          short col_pos = sn + (j * stile_t::kFragCols);
          STEEL_PRAGMA_UNROLL
          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
            if ((col_pos + jj) >= lim) {
              Stile.frag_at(i, j)[jj] = neg_inf;
            }
          }
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load the V block, bounds-checked on the last (ragged) K block
    if (!align_K && kb == (params->NK_aligned)) {
      loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
    } else {
      loader_v.load_unsafe();
    }
    // Online softmax update: track running row max and sum
    AccumType new_max[kRowsPT];
    AccumType factor[kRowsPT];
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      new_max[i] = max_score[i];
    }

    // Row max, then exp2(S - rowmax); log2(e) was folded into the scale
    Stile.template row_reduce<MaxOp>(new_max);
    Stile.template row_bin_op<ExpSubOp>(new_max);

    // Rescale factor exp2(old rowmax - new rowmax), then save the new max
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      factor[i] = fast::exp2(max_score[i] - new_max[i]);
    }
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      max_score[i] = new_max[i];
    }

    // Row sum, then update the running denominator
    AccumType sum_score_tmp[kRowsPT] = {0};
    Stile.template row_reduce<SumOp>(sum_score_tmp);
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
    }

    // Rescale the output accumulated so far
    Otile.template row_bin_op<MulOp>(factor);
    // O += S @ V
    threadgroup_barrier(mem_flags::mem_threadgroup);

    STEEL_PRAGMA_UNROLL
    for (short iq = 0; iq < TQ; iq++) {
      STEEL_PRAGMA_UNROLL
      for (short id = 0; id < TD; id++) {
        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          if constexpr (BD == 128) {
            simdgroup_barrier(mem_flags::mem_none);
          }

          const short kk = ik * kFragSize;
          const short dd = id * kFragSize;

          Vtile.template load<T, 1, 1, LDV_tgp, 1>(
              &Vs[Vs_offset + kk * LDV_tgp + dd]);

          if constexpr (BD == 128) {
            simdgroup_barrier(mem_flags::mem_none);
          }

          MMAFrag_acc_t::mma(
              Otile.frag_at(iq, id),
              Stile.frag_at(iq, ik),
              Vtile.frag_at(0, 0),
              Otile.frag_at(iq, id));
        }
      }
    }

    // Advance the K and V loaders to the next block
    loader_k.next();
    loader_v.next();
  }
  // Normalize the output by the softmax denominator
  Otile.template row_bin_op<DivOp>(sum_score);
  threadgroup_barrier(mem_flags::mem_none);

  // Store results, bounds-checked if this is the last (ragged) Q tile
  O += (tm + sm) * params->O_strides[2] + sn;

  if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
    auto dst_tile_dims =
        short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
  } else {
    Otile.template store<T, 1, 1>(O, params->O_strides[2]);
  }
}
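
// Dispatch sketch (an assumption, not part of the original file): the
// indexing above implies one threadgroup per Q block per head per batch,
// i.e. a (NQ blocks, num_heads, batch) threadgroup grid launched with
// WM * WN * 32 threads per threadgroup.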