[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]]
void gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device T* C [[buffer(2), function_constant(use_out_source)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params
        [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const constant uint32_t* lhs_indices
        [[buffer(10), function_constant(do_gather)]],
    const constant uint32_t* rhs_indices
        [[buffer(11), function_constant(do_gather)]],
    const constant uint32_t* C_indices
        [[buffer(12), function_constant(gather_bias)]],
    const constant int* operand_shape
        [[buffer(13), function_constant(do_gather)]],
    const constant size_t* operand_strides
        [[buffer(14), function_constant(do_gather)]],
    const constant packed_int3& operand_batch_ndim
        [[buffer(15), function_constant(do_gather)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
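  // Arguments tagged with function_constant are compiled in or out when
  // the pipeline is specialized: use_out_source gates the C operand and
  // its addmm parameters, do_gather gates the index arrays and the
  // per-operand shapes/strides, and gather_bias gates C_indices. The
  // constants themselves are set host-side at pipeline build time (not
  // shown in this excerpt).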
 
  // ... (gemm_kernel type alias setup elided) ...

  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;
 
  // Find the output tile this threadgroup owns, unswizzling the
  // threadgroup grid position.
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;
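  // Worked example of the unswizzle (illustrative numbers, not from the
  // source): with swizzle_log = 2 and (tid.x, tid.y) = (5, 3), we get
  // tid_y = (3 << 2) + (5 & 3) = 13 and tid_x = 5 >> 2 = 1. Consecutive
  // threadgroups along grid x therefore land on a column of
  // 2^swizzle_log vertically adjacent output tiles, improving cache
  // reuse of the A and B tiles they share.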
 
  // Exit early if this threadgroup lies outside the tile grid.
  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Adjust operand pointers for the batch; the gather path comes first.
  if (do_gather) {
    // Read the operand indices for this batch element.
    uint32_t indx_A, indx_B, indx_C;
 
    if (has_batch) {
      // The index arrays may themselves be batched; locate this tid.z's
      // entries through the batch shape and strides.
      const constant size_t* indx_A_bstrides = batch_strides;
      const constant size_t* indx_B_bstrides =
          batch_strides + params->batch_ndim;

      ulong2 indx_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, indx_A_bstrides, indx_B_bstrides,
          params->batch_ndim);
      indx_A = lhs_indices[indx_offsets.x];
      indx_B = rhs_indices[indx_offsets.y];

      if (use_out_source) {
        const constant size_t* indx_C_bstrides =
            indx_B_bstrides + params->batch_ndim;
        auto indx_offset_C = elem_to_loc(
            tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
        indx_C = C_indices[indx_offset_C];
      }
    } else {
      indx_A = lhs_indices[params->batch_stride_a * tid.z];
      indx_B = rhs_indices[params->batch_stride_b * tid.z];

      if (use_out_source) {
        indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
      }
    }
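    // In the gather path, batch element z does not step through A and B
    // by a fixed stride; instead lhs_indices / rhs_indices (and
    // C_indices when the bias is gathered too) name which slice of each
    // operand this z should multiply.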
 
    // Translate the gathered indices into element offsets. The batch
    // shapes and strides of A, B, and C are packed back to back in
    // operand_shape / operand_strides, with each operand's ndim carried
    // in operand_batch_ndim.
    int batch_ndim_A = operand_batch_ndim.x;
    const constant int* batch_shape_A = operand_shape;
    const constant size_t* batch_strides_A = operand_strides;
    A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);

    int batch_ndim_B = operand_batch_ndim.y;
    const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
    const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
    B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);

    if (use_out_source) {
      int batch_ndim_C = operand_batch_ndim.z;
      const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
      const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
      C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
    }
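    // Layout sketch for the tables walked above (inferred from the
    // pointer arithmetic, not stated in the excerpt):
    //
    //   operand_shape   = [shape_A... | shape_B... | shape_C...]
    //   operand_strides = [strides_A... | strides_B... | strides_C...]
    //
    // Each operand's entries start where the previous operand's end,
    // batch_ndim_{A,B,C} entries apiece.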
 
  }
  // Handle the regular (non-gather) batch.
  else {
    if (has_batch) {
      const constant size_t* A_bstrides = batch_strides;
      const constant size_t* B_bstrides = batch_strides + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

      A += batch_offsets.x;
      B += batch_offsets.y;

      if (use_out_source) {
        const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
        C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
      }
    } else {
      // Contiguous batch: plain strided offsets.
      A += params->batch_stride_a * tid.z;
      B += params->batch_stride_b * tid.z;

      if (use_out_source) {
        C += addmm_params->batch_stride_c * tid.z;
      }
    }
  }
 
  // D is written densely, so its batch offset is always a plain stride.
  D += params->batch_stride_d * tid.z;

  // Prepare threadgroup memory for the A and B tiles.
  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  threadgroup_barrier(mem_flags::mem_none);
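  // tgp_mem_size_a and tgp_mem_size_b are compile-time constants exposed
  // by GEMMKernel; presumably they cover one BM x BK tile of A and one
  // BK x BN tile of B (plus any padding the loaders want), since that is
  // what gets staged per K iteration below.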
 
  // Find the output tile in D that this threadgroup computes.
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  // Advance the operand pointers to that tile.
  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  if (use_out_source) {
    C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
  }
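  // Addressing sketch: with row-major storage, stepping down c_row rows
  // of A costs c_row * lda elements unless A is transposed, in which
  // case rows are unit-stride (hence the ternaries above). C carries an
  // extra fdc stride for its fastest-moving dimension; presumably this
  // is what lets the same epilogue walk broadcast C layouts as well as
  // dense ones.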
 
  // Prepare the MMA operation and the threadgroup tile loaders.
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
 
  // Actual extent of this tile; edge tiles may be short in M or N.
  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));

  int gemm_k_iterations = params->gemm_k_iterations_aligned;
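  // align_M, align_N, and align_K are function constants fixed when the
  // pipeline is specialized, so for fully aligned problem sizes the
  // min() clamps and every bounds-checked branch below are resolved and
  // stripped at compile time.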
 
  // Handle the unaligned K remainder, if any, before the main loop.
  if (!align_K) {
    const int k_last = params->gemm_k_iterations_aligned * BK;
    const int k_remain = params->K - k_last;
    const size_t k_jump_a =
        transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
    const size_t k_jump_b =
        transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

    // Move the loader sources to the final, partial K block.
    loader_a.src += k_jump_a;
    loader_b.src += k_jump_b;

    // Load the partial tile with bounds checks.
    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

    loader_a.load_safe(tile_dims_A);
    loader_b.load_safe(tile_dims_B);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply-accumulate the remainder, then reset the sources so the
    // main loop runs over the aligned K blocks from the start.
    mma_op.mma(As, Bs);

    loader_a.src -= k_jump_a;
    loader_b.src -= k_jump_b;
  }

  // Epilogue transforms for the addmm path (D = alpha * (A @ B) + beta * C).
  const TransformAdd<AccumType, AccumType> epilogue_op_add(
      addmm_params->alpha, addmm_params->beta);
  const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
      addmm_params->alpha, addmm_params->beta);
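  // Doing the ragged K tail first is deliberate: the tail is the only
  // part that needs load_safe, so the main loop can use the cheaper
  // unchecked loads, and the accumulators simply carry the tail's
  // partial products through the remaining aligned iterations.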
 
  // MNK-aligned fast path: every tile is full, so loads skip bounds checks.
  if (align_M && align_N) {
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Stage the next A and B tiles in threadgroup memory.
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply-accumulate, then advance to the next K block.
      mma_op.mma(As, Bs);
      loader_a.next();
      loader_b.next();
    }
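    // The two barriers per iteration form a load/compute handshake: the
    // first keeps the loaders from overwriting As/Bs while the previous
    // iteration's mma may still be reading them, and the second ensures
    // the new tiles are fully staged before any simdgroup multiplies. A
    // double-buffered scheme could overlap the two phases at the cost of
    // twice the threadgroup memory; this kernel keeps single buffers.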
 
    threadgroup_barrier(mem_flags::mem_none);

    // Epilogue: fold C into the accumulators if requested.
    if (use_out_source) {
      if (do_axpby) {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
      } else {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
      }
    }

    // Store the result tile to device memory.
    return mma_op.store_result(D, params->ldd);
  }
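  // Note the early return: fully aligned GEMMs never reach the edge-tile
  // dispatch below, so the common case pays nothing for it.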
 
  // MN-unaligned path: dispatch on which dimensions still have full tiles.
  else {
    const int leftover_bk = 0;
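    // Four specializations follow, one per combination of short M / short
    // N edges. Each runs the shared gemm_kernel::gemm_loop (its argument
    // list is elided in this excerpt; in the full kernel the variants
    // differ by a compile-time LoopAlignment tag) so that only the short
    // dimensions pay for bounds checking.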
 
    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
      // Both dimensions are full tiles: no bounds checks needed.
      gemm_kernel::gemm_loop(
          /* ... tile buffers, loaders, mma_op, and loop bounds elided ... */);

      // Epilogue: fold C into the accumulators if requested.
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
        }
      }

      // Store results to device memory.
      return mma_op.store_result(D, params->ldd);
 
    } else if (align_N || tgp_bn == BN) {
      // N is a full tile but M is short: bounds-check the M edge.
      gemm_kernel::gemm_loop(
          /* ... as above, tagged to check the M edge ... */);

      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_add);
        }
      }

      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
 
    } else if (align_M || tgp_bm == BM) {
      // M is a full tile but N is short: bounds-check the N edge.
      gemm_kernel::gemm_loop(
          /* ... as above, tagged to check the N edge ... */);

      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_add);
        }
      }

      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
 
    } else {
      // Both edges are short: fully bounds-checked loop and stores.
      gemm_kernel::gemm_loop(
          /* ... as above, tagged to check both edges ... */);

      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_add);
        }
      }

      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}