77[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] 
void attention(
 
   78    const device T* Q [[buffer(0)]],
 
   79    const device T* K [[buffer(1)]],
 
   80    const device T* V [[buffer(2)]],
 
   81    device T* O [[buffer(3)]],
 
   82    const constant 
AttnParams* params [[buffer(4)]],
 
   84    const device MaskType* mask [[buffer(6), function_constant(
has_mask)]],
 
   85    uint simd_lane_id [[thread_index_in_simdgroup]],
 
   86    uint simd_group_id [[simdgroup_index_in_threadgroup]],
 
   87    uint3 tid [[threadgroup_position_in_grid]],
 
   88    uint3 lid [[thread_position_in_threadgroup]]) { 
 
   94  ulong3 tidl{tid.x, tid.y, tid.z};
 
   96  Q += tidl.z * params->Q_strides[0] + 
 
   97      tidl.y * params->Q_strides[1] + 
 
   98      tidl.x * BQ * params->Q_strides[2]; 
 
  100  ulong kv_head_idx = int(tid.y) / params->gqa_factor;
 
  101  K += tidl.z * params->K_strides[0] + 
 
  102      kv_head_idx * params->K_strides[1]; 
 
  104  V += tidl.z * params->V_strides[0] + 
 
  105      kv_head_idx * params->V_strides[1]; 
 
  107  O += tidl.z * params->O_strides[0] + 
 
  108      tidl.y * params->O_strides[1] + 
 
  109      tidl.x * BQ * params->O_strides[2]; 
 
  112    mask += tidl.z * mask_params->M_strides[0] + 
 
  113        tidl.y * mask_params->M_strides[1]; 
 
  117  constexpr short padQ = 16 / 
sizeof(T);
 
  118  constexpr short padK = 16 / 
sizeof(T);
 
  119  constexpr short padV = 16 / 
sizeof(T);
 
  121  constexpr short LDQ_tgp = BD + padQ;
 
  122  constexpr short LDK_tgp = BK + padK;
 
  123  constexpr short LDV_tgp = BD + padV;
 
  125  constexpr short tgp_mem_0 = (BK + padK) * (BD);
 
  126  constexpr short tgp_mem_1 = BK * (BD + padV);
 
  127  constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1;
 
  129  threadgroup T Q_smem[BQ * (BD + padQ)];
 
  130  threadgroup T KV_smem[tgp_mem_s];
 
  132  threadgroup T* Qs = Q_smem;
 
  133  threadgroup T* Ks = KV_smem;
 
  134  threadgroup T* Vs = KV_smem;
 
  165  QBlockLoader loader_q(
 
  166      Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
 
  167  KBlockLoader loader_k(
 
  168      K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
 
  169  VBlockLoader loader_v(
 
  170      V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
 
  175  constexpr short kFragSize = 8; 
 
  178  constexpr int kNWarps = WM * WN;
 
  180      BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
 
  181      "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
 
  184  constexpr int TQ = BQ / (kNWarps * kFragSize);
 
  186  constexpr int TK = BK / kFragSize;
 
  188  constexpr int TD = BD / kFragSize;
 
  190  static_assert(TQ == 1, 
"Check TQ");
 
  201  const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
 
  202  const short sm = simd_coord.y;
 
  203  const short sn = simd_coord.x;
 
  204  const short tm = kFragSize * TQ * simd_group_id;
 
  206  const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
 
  207  const short Ks_offset = sm * LDK_tgp + sn;
 
  208  const short Vs_offset = sm * LDV_tgp + sn;
 
  210  constexpr short Qs_tile_stride = kFragSize;
 
  211  constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
 
  213  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  216  if (!
align_Q && 
int(tid.x) == (params->NQ_aligned)) {
 
  217    loader_q.load_safe(short2(BD, params->qL_rem));
 
  219    loader_q.load_unsafe();
 
  221  loader_q.apply_inplace_op(ts);
 
  224  constexpr short kRowsPT = 
decltype(Stile)::kRowsPerThread;
 
  226  AccumType max_score[kRowsPT];
 
  227  AccumType sum_score[kRowsPT] = {0};
 
  231  for (
short i = 0; i < kRowsPT; ++i) {
 
  235  int kb_lim = params->NK;
 
  238    int q_max = (tid.x + 1) * BQ + params->qL_off;
 
  239    kb_lim = (q_max + BK - 1) / BK;
 
  243  for (
int kb = 0; kb < kb_lim; kb++) {
 
  245    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  246    if (!
align_K && kb == (params->NK_aligned)) {
 
  247      loader_k.load_safe(short2(BD, params->kL_rem));
 
  249      loader_k.load_unsafe();
 
  255    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  258    for (
short dd = 0; dd < TD; dd++) {
 
  259      simdgroup_barrier(mem_flags::mem_none);
 
  261      Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
 
  262          &Qs[Qs_offset + dd * Qs_tile_stride]);
 
  263      Ktile.template load<T, 1, 1, LDK_tgp, 1>(
 
  264          &Ks[Ks_offset + dd * Ks_tile_stride]);
 
  266      simdgroup_barrier(mem_flags::mem_none);
 
  272    if (!
align_K && kb == (params->NK_aligned)) {
 
  273      using stile_t = 
decltype(Stile);
 
  274      using selem_t = 
typename stile_t::elem_type;
 
  275      constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
 
  278      for (
short i = 0; i < stile_t::kTileRows; i++) {
 
  280        for (
short j = 0; j < stile_t::kTileCols; j++) {
 
  281          short col_pos = sn + (j * stile_t::kFragCols);
 
  283          for (
short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
 
  284            if ((col_pos + jj) >= params->kL_rem) {
 
  285              Stile.
frag_at(i, j)[jj] = neg_inf;
 
  294      using stile_t = 
decltype(Stile);
 
  295      using selem_t = 
typename stile_t::elem_type;
 
  296      constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
 
  299      for (
short i = 0; i < stile_t::kTileRows; i++) {
 
  301            tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows);
 
  303        for (
short j = 0; j < stile_t::kTileCols; j++) {
 
  304          const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
 
  306          for (
short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
 
  307            if (row_pos < (col_pos + jj)) {
 
  308              Stile.
frag_at(i, j)[jj] = neg_inf;
 
  317      using stile_t = 
decltype(Stile);
 
  318      using selem_t = 
typename stile_t::elem_type;
 
  319      constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
 
  321      constexpr bool is_bool = is_same_v<MaskType, bool>;
 
  322      using melem_t = 
typename metal::conditional_t<is_bool, bool, selem_t>;
 
  325      using frag_t = 
typename MMAFrag_mask_t::frag_type;
 
  328      for (
short i = 0; i < stile_t::kTileRows; i++) {
 
  329        const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows);
 
  331        for (
short j = 0; j < stile_t::kTileCols; j++) {
 
  332          const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
 
  336          MMAFrag_mask_t::load_safe(
 
  339              int(mask_params->M_strides[2]),
 
  347          for (
short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) {
 
  348            if constexpr (is_bool) {
 
  350                  mfrag[jj] ? Stile.
frag_at(i, j)[jj] : neg_inf;
 
  352              Stile.
frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
 
  359    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  362    if (!
align_K && kb == (params->NK_aligned)) {
 
  363      loader_v.load_safe(short2(BD, params->kL_rem));
 
  365      loader_v.load_unsafe();
 
  371    AccumType new_max[kRowsPT];
 
  372    AccumType factor[kRowsPT];
 
  374    for (
short i = 0; i < kRowsPT; ++i) {
 
  375      new_max[i] = max_score[i];
 
  379    Stile.template row_reduce<MaxOp>(new_max);
 
  382    Stile.template row_bin_op<ExpSubOp>(new_max);
 
  386    for (
short i = 0; i < kRowsPT; ++i) {
 
  387      factor[i] = fast::exp2(max_score[i] - new_max[i]);
 
  392    for (
short i = 0; i < kRowsPT; ++i) {
 
  393      max_score[i] = new_max[i];
 
  397    AccumType sum_score_tmp[kRowsPT] = {0};
 
  398    Stile.template row_reduce<SumOp>(sum_score_tmp);
 
  402    for (
short i = 0; i < kRowsPT; ++i) {
 
  403      sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
 
  407    Otile.template row_bin_op<MulOp>(factor);
 
  410    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  413    for (
short iq = 0; iq < TQ; iq++) {
 
  415      for (
short id = 0; 
id < TD; 
id++) {
 
  417        for (
short ik = 0; ik < TK; ik++) {
 
  418          if constexpr (BD == 128) {
 
  419            simdgroup_barrier(mem_flags::mem_none);
 
  422          const short kk = ik * kFragSize;
 
  423          const short dd = 
id * kFragSize;
 
  425          Vtile.template load<T, 1, 1, LDV_tgp, 1>(
 
  426              &Vs[Vs_offset + kk * LDV_tgp + dd]);
 
  428          if constexpr (BD == 128) {
 
  429            simdgroup_barrier(mem_flags::mem_none);
 
  447  Otile.template row_bin_op<DivOp>(sum_score);
 
  448  threadgroup_barrier(mem_flags::mem_none);
 
  451  O += (tm + sm) * params->O_strides[2] + sn;
 
  453  if (!
align_Q && 
int(tid.x) == (params->NQ_aligned)) {
 
  454    auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));
 
  456    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
 
  459    Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
 
  461    Otile.template store<T, 1, 1>(O, params->O_strides[2]);