// Split-K GEMM compute kernel (Metal Shading Language, MLX "steel" GEMM family).
// NOTE(review): this listing is a damaged doxygen extraction — the original
// file's line numbers (21, 22, 44, ...) are fused into the code text, and
// several interior lines are missing (e.g. the `GEMMSpiltKParams* params`
// kernel argument at buffer(3), the body of the tile bounds check, and the
// argument lists of the gemm_loop calls). Comments below describe only what
// the visible lines establish; do not treat this fragment as compilable.
21[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] 
void gemm_splitk(
 
   22    const device T* A [[buffer(0)]],
 
   23    const device T* B [[buffer(1)]],
 
   24    device U* C [[buffer(2)]],
 
// NOTE(review): original line 25 is missing here — presumably the
// `params` argument at buffer(3); the body dereferences `params->` below.
   26    uint simd_lane_id [[thread_index_in_simdgroup]],
 
   27    uint simd_group_id [[simdgroup_index_in_threadgroup]],
 
   28    uint3 tid [[threadgroup_position_in_grid]],
 
   29    uint3 lid [[thread_position_in_threadgroup]]) {
 
// Pull the cooperative tile-loader and simdgroup-MMA helper types out of the
// shared gemm_kernel definition.
   44  using loader_a_t = 
typename gemm_kernel::loader_a_t;
 
   45  using loader_b_t = 
typename gemm_kernel::loader_b_t;
 
   46  using mma_t = 
typename gemm_kernel::mma_t;
 
// Threadgroup-memory staging tiles for the A and B operands.
   48  threadgroup T As[gemm_kernel::tgp_mem_size_a];
 
   49  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
 
// Tile coordinates: x/y index the output tile, z indexes the split-K partition.
   51  const int tid_x = tid.x;
 
   52  const int tid_y = tid.y;
 
   53  const int tid_z = tid.z;
 
// Early exit for threadgroups that fall outside the tile grid (the `return;`
// body of this `if` is not visible in this extraction).
   55  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
 
// Output-tile origin and this partition's starting offset along K.
   60  const int c_row = tid_y * BM;
 
   61  const int c_col = tid_x * BN;
 
   62  const int k_start = params->split_k_partition_size * tid_z;
 
// Widen to size_t before pointer arithmetic so large matrices do not overflow
// 32-bit index math.
   64  const size_t c_row_long = size_t(c_row);
 
   65  const size_t c_col_long = size_t(c_col);
 
   66  const size_t k_start_long = size_t(k_start);
 
// Advance A and B to this tile/partition; row and K offsets swap roles
// depending on the transpose flags.
   68  A += transpose_a ? (c_row_long + k_start_long * params->lda)
 
   69                   : (k_start_long + c_row_long * params->lda);
 
   70  B += transpose_b ? (k_start_long + c_col_long * params->ldb)
 
   71                   : (c_col_long + k_start_long * params->ldb);
 
// Each split-K partition writes its partial result into its own slice of C
// (offset by split_k_partition_stride); partials are reduced later by one of
// the gemm_splitk_accum kernels.
   72  C += (size_t(params->split_k_partition_stride) * tid_z) +
 
   73      (c_row_long * params->ldc + c_col_long);
 
// Cooperative tile loaders and the per-simdgroup MMA accumulator.
   76  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
 
   77  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
 
   80  thread mma_t mma_op(simd_group_id, simd_lane_id);
 
   82  int gemm_k_iterations = params->gemm_k_iterations_aligned;
 
// Edge-tile extents and the K remainder that is not a multiple of BK.
   84  short tgp_bm = min(BM, params->M - c_row);
 
   85  short tgp_bn = min(BN, params->N - c_col);
 
   86  short leftover_bk = params->K % BK;
 
// Dispatch to the fully-aligned fast path or to an edge-handling gemm_loop
// variant. NOTE(review): the template/argument lists of these gemm_loop calls
// (original lines 90-99, 102-111, ...) are missing from this extraction.
   88  if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
 
   89    gemm_kernel::gemm_loop(
 
  100  } 
else if (tgp_bn == BN) {
 
  101    gemm_kernel::gemm_loop(
 
  112  } 
else if (tgp_bm == BM) {
 
  113    gemm_kernel::gemm_loop(
 
  125    gemm_kernel::gemm_loop(
 
  138  threadgroup_barrier(mem_flags::mem_threadgroup);
 
// The last partition along K also processes the leftover K iterations that do
// not divide evenly among the split-K partitions.
  140  if ((tid_z + 1) == (params->split_k_partitions)) {
 
  141    int gemm_k_iter_remaining =
 
  142        (params->K - (k_start + params->split_k_partition_size)) / BK;
 
  143    if (!K_aligned || gemm_k_iter_remaining > 0)
 
  144      gemm_kernel::gemm_loop(
 
  147          gemm_k_iter_remaining,
 
// Full tiles take the unchecked store; edge tiles take the bounds-checked one.
  157  if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
 
  158    mma_op.store_result(C, params->ldc);
 
  160    mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
 
 
// Parameter list and body fragment of gemm_splitk_accum (the kernel attribute
// and signature line, original lines ~168-172, are missing from this
// extraction). Reduces the k_partitions split-K partial sums stored in
// C_split into a single element of D, applying the Epilogue transform.
  173    const device AccT* C_split [[buffer(0)]],
 
  174    device OutT* D [[buffer(1)]],
 
  175    const constant 
int& k_partitions [[buffer(2)]],
 
  176    const constant 
int& partition_stride [[buffer(3)]],
 
  177    const constant 
int& ldd [[buffer(4)]],
 
  178    uint2 gid [[thread_position_in_grid]]) {
 
// One thread per output element: gid.x indexes within a row, gid.y selects the
// row with leading dimension ldd (same layout for D and C_split).
  180  D += gid.x + gid.y * size_t(ldd);
 
  181  C_split += gid.x + gid.y * size_t(ldd);
 
// Accumulate the partial result of every partition; successive partitions are
// spaced partition_stride elements apart. NOTE(review): the declarations of
// `out` and `offset` (original lines ~183-185) are missing from this extraction.
  186  for (
int i = 0; i < k_partitions; i++) {
 
  187    out += C_split[offset];
 
  188    offset += partition_stride;
 
// Convert the accumulated value to the output type via the Epilogue and store.
  192  D[0] = Epilogue::apply(out);
 
 
// Parameter list and body fragment of gemm_splitk_accum_axpby (the kernel
// attribute and signature line, original lines ~195-199, are missing from this
// extraction). Like gemm_splitk_accum, but additionally reads the existing
// output C and blends it with the reduced sum through an Epilogue constructed
// from (alpha, beta) — presumably alpha*sum + beta*C; confirm against the
// Epilogue implementation.
  200    const device AccT* C_split [[buffer(0)]],
 
  201    device OutT* D [[buffer(1)]],
 
  202    const constant 
int& k_partitions [[buffer(2)]],
 
  203    const constant 
int& partition_stride [[buffer(3)]],
 
  204    const constant 
int& ldd [[buffer(4)]],
 
// Existing output operand: ldc is its row stride, fdc its per-element stride.
  205    const device OutT* C [[buffer(5)]],
 
  206    const constant 
int& ldc [[buffer(6)]],
 
  207    const constant 
int& fdc [[buffer(7)]],
 
  208    const constant 
float& alpha [[buffer(8)]],
 
  209    const constant 
float& beta [[buffer(9)]],
 
  210    uint2 gid [[thread_position_in_grid]]) {
 
// One thread per output element; C uses its own strides (fdc/ldc), while D and
// C_split share the dense layout with leading dimension ldd.
  212  C += gid.x * size_t(fdc) + gid.y * size_t(ldc);
 
  213  D += gid.x + gid.y * size_t(ldd);
 
  214  C_split += gid.x + gid.y * size_t(ldd);
 
// Sum the per-partition partials, partition_stride elements apart.
// NOTE(review): the declarations of `out` and `offset` (original lines
// ~216-218) are missing from this extraction.
  219  for (
int i = 0; i < k_partitions; i++) {
 
  220    out += C_split[offset];
 
  221    offset += partition_stride;
 
// Blend with the existing C value and store the final output element.
  225  Epilogue 
op(alpha, beta);
 
  226  D[0] = 
op.apply(out, *C);
 
 
void gemm_splitk(const device T *A, const device T *B, device U *C, const constant GEMMSpiltKParams *params, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
Definition steel_gemm_splitk.h:21
 
void gemm_splitk_accum(const device AccT *C_split, device OutT *D, const constant int &k_partitions, const constant int &partition_stride, const constant int &ldd, uint2 gid)
Definition steel_gemm_splitk.h:172
 
void gemm_splitk_accum_axpby(const device AccT *C_split, device OutT *D, const constant int &k_partitions, const constant int &partition_stride, const constant int &ldd, const device OutT *C, const constant int &ldc, const constant int &fdc, const constant float &alpha, const constant float &beta, uint2 gid)
Definition steel_gemm_splitk.h:199