80      threadgroup T* As [[threadgroup(0)]],
 
   81      threadgroup T* Bs [[threadgroup(1)]],
 
   82      const int gemm_k_iterations,
 
   86      thread 
const short& tgp_bm,
 
   87      thread 
const short& tgp_bn,
 
   88      thread 
const short& lbk,
 
   93    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
 
   95    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
 
   97    for (
int k = 0; k < gemm_k_iterations; k++) {
 
   98      threadgroup_barrier(mem_flags::mem_threadgroup);
 
  101        loader_a.load_unsafe();
 
  103        loader_a.load_safe(tile_dims_A);
 
  107        loader_b.load_unsafe();
 
  109        loader_b.load_safe(tile_dims_B);
 
  112      threadgroup_barrier(mem_flags::mem_threadgroup);
 
  123      threadgroup_barrier(mem_flags::mem_threadgroup);
 
  125      short2 tile_dims_A_last =
 
  126          transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
 
  127      short2 tile_dims_B_last =
 
  128          transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
 
  130      loader_a.load_safe(tile_dims_A_last);
 
  131      loader_b.load_safe(tile_dims_B_last);
 
  133      threadgroup_barrier(mem_flags::mem_threadgroup);
 
 
  140  static METAL_FUNC 
void run(
 
  141      const device T* A [[buffer(0)]],
 
  142      const device T* B [[buffer(1)]],
 
  143      device U* D [[buffer(2)]],
 
  144      const constant 
GEMMParams* params [[buffer(3)]],
 
  145      threadgroup T* As [[threadgroup(0)]],
 
  146      threadgroup T* Bs [[threadgroup(1)]],
 
  147      uint simd_lane_id [[thread_index_in_simdgroup]],
 
  148      uint simd_group_id [[simdgroup_index_in_threadgroup]],
 
  149      uint3 tid [[threadgroup_position_in_grid]],
 
  150      uint3 lid [[thread_position_in_threadgroup]]) {
 
  154    const int tid_y = ((tid.y) << params->swizzle_log) +
 
  155        ((tid.x) & ((1 << params->swizzle_log) - 1));
 
  156    const int tid_x = (tid.x) >> params->swizzle_log;
 
  158    if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
 
  162    threadgroup_barrier(mem_flags::mem_none);
 
  165    const int c_row = tid_y * BM;
 
  166    const int c_col = tid_x * BN;
 
  167    const size_t c_row_long = size_t(c_row);
 
  168    const size_t c_col_long = size_t(c_col);
 
  170    A += transpose_a ? c_row_long : c_row_long * params->lda;
 
  171    B += transpose_b ? c_col_long * params->ldb : c_col_long;
 
  172    D += c_row_long * params->ldd + c_col_long;
 
  175    thread 
loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
 
  176    thread 
loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
 
  179    thread 
mma_t mma_op(simd_group_id, simd_lane_id);
 
  181    int gemm_k_iterations = params->gemm_k_iterations_aligned;
 
  186      for (
int k = 0; k < gemm_k_iterations; k++) {
 
  187        threadgroup_barrier(mem_flags::mem_threadgroup);
 
  189        loader_a.load_unsafe();
 
  190        loader_b.load_unsafe();
 
  192        threadgroup_barrier(mem_flags::mem_threadgroup);
 
  202      threadgroup_barrier(mem_flags::mem_none);
 
  206        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
 
  207        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
 
  208        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
 
  210        loader_a.load_safe(tile_dims_A);
 
  211        loader_b.load_safe(tile_dims_B);
 
  213        threadgroup_barrier(mem_flags::mem_threadgroup);
 
  219      mma_op.store_result(D, params->ldd);
 
  226      short tgp_bm = 
min(BM, params->M - c_row);
 
  227      short tgp_bn = 
min(BN, params->N - c_col);
 
  228      short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
 
  230      if (tgp_bm == BM && tgp_bn == BN) {
 
  242        mma_op.store_result(D, params->ldd);
 
  245      } 
else if (tgp_bn == BN) {
 
  257        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
 
  260      } 
else if (tgp_bm == BM) {
 
  272        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
 
  287        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));