18    const device T* A [[buffer(0)]],
 
   19    const device T* B [[buffer(1)]],
 
   20    device T* C [[buffer(2)]],
 
   22    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
 
   23    uint3 tid [[threadgroup_position_in_grid]],
 
   24    uint3 lid [[thread_position_in_threadgroup]],
 
   25    uint simd_gid [[simdgroup_index_in_threadgroup]],
 
   26    uint simd_lid [[thread_index_in_simdgroup]]) {
 
   31  constexpr bool transpose_a = 
false;
 
   32  constexpr bool transpose_b = 
true;
 
   33  constexpr short tgp_padding_a = 16 / 
sizeof(T);
 
   34  constexpr short tgp_padding_b = 16 / 
sizeof(T);
 
   36  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
 
   37  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
 
   38  constexpr short shape_a_rows = (transpose_a ? BK : BM);
 
   39  constexpr short shape_b_rows = (transpose_b ? BN : BK);
 
   40  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
 
   41  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
 
   43  constexpr short tgp_size = WM * WN * 32;
 
   47  using loader_a_t = 
typename metal::conditional_t<
 
   49      N_CHANNELS != 0 && N_CHANNELS <= 4,
 
   62      typename metal::conditional_t<
 
   85  using loader_b_t = 
typename metal::conditional_t<
 
   87      N_CHANNELS != 0 && N_CHANNELS <= 4,
 
  115  threadgroup T As[tgp_mem_size_a];
 
  116  threadgroup T Bs[tgp_mem_size_b];
 
  118  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
 
  119      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
 
  120  const int tid_x = (tid.x) >> gemm_params->swizzle_log;
 
  122  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
 
  126  const int c_row = tid_y * BM;
 
  127  const int c_col = tid_x * BN;
 
  128  const int K = gemm_params->K;
 
  129  const int N = gemm_params->N;
 
  130  const int C_per_group = params->C / params->groups;
 
  133  A += tid.z * C_per_group;
 
  138  C += c_row * (N * params->groups) + c_col;
 
  140  const int2 offsets_a(0, c_row);
 
  141  const int2 offsets_b(0, c_col);
 
  145      A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
 
  147      B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
 
  150  mma_t mma_op(simd_gid, simd_lid);
 
  152  int gemm_k_iterations = gemm_params->gemm_k_iterations;
 
  153  for (
int k = 0; k < gemm_k_iterations; k++) {
 
  154    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  156    loader_a.load_unsafe();
 
  157    loader_b.load_unsafe();
 
  159    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  169  threadgroup_barrier(mem_flags::mem_none);
 
  172  short tgp_bm = 
min(BM, gemm_params->M - c_row);
 
  173  short tgp_bn = 
min(BN, gemm_params->N - c_col);
 
  174  const int ldc = N * params->groups;
 
  175  mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));
 
 
void implicit_gemm_conv_2d(const device T *A, const device T *B, device T *C, const constant MLXConvParams< 2 > *params, const constant ImplicitGemmConv2DParams *gemm_params, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid)
Definition steel_conv.h:17