16    const device T* A [[buffer(0)]],
 
   17    const device T* B [[buffer(1)]],
 
   18    device T* C [[buffer(2)]],
 
   20    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
 
   21    const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
 
   22    const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
 
   23    const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
 
   24    uint3 tid [[threadgroup_position_in_grid]],
 
   25    uint3 lid [[thread_position_in_threadgroup]],
 
   26    uint simd_gid [[simdgroup_index_in_threadgroup]],
 
   27    uint simd_lid [[thread_index_in_simdgroup]]) {
 
   30  constexpr bool transpose_a = 
false;
 
   31  constexpr bool transpose_b = 
true;
 
   32  constexpr short tgp_padding_a = 16 / 
sizeof(T);
 
   33  constexpr short tgp_padding_b = 16 / 
sizeof(T);
 
   35  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
 
   36  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
 
   37  constexpr short shape_a_rows = (transpose_a ? BK : BM);
 
   38  constexpr short shape_b_rows = (transpose_b ? BN : BK);
 
   39  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
 
   40  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
 
   42  constexpr short tgp_size = WM * WN * 32;
 
   46      Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;
 
   50      Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;
 
   52  using mma_t = BlockMMA<
 
   65  threadgroup T As[tgp_mem_size_a];
 
   66  threadgroup T Bs[tgp_mem_size_b];
 
   68  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
 
   69      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
 
   70  const int tid_x = (tid.x) >> gemm_params->swizzle_log;
 
   72  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
 
   76  const int tid_z = tid.z;
 
   78  const int base_oh = tid_z / jump_params->f_out_jump_w;
 
   79  const int base_ow = tid_z % jump_params->f_out_jump_w;
 
   81  const int base_wh = base_h[base_oh].weight_base;
 
   82  const int base_ww = base_w[base_ow].weight_base;
 
   84  const int base_wh_size = base_h[base_oh].weight_size;
 
   85  const int base_ww_size = base_w[base_ow].weight_size;
 
   87  const int c_row = tid_y * BM;
 
   88  const int c_col = tid_x * BN;
 
   89  const int K = gemm_params->K;
 
   93  const int4 offsets_a(0, c_row, base_oh, base_ow);
 
   94  const int2 offsets_b(0, c_col);
 
  119  mma_t mma_op(simd_gid, simd_lid);
 
  121  int gemm_k_iterations =
 
  122      base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
 
  124  for (
int k = 0; k < gemm_k_iterations; k++) {
 
  125    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  127    loader_a.load_unsafe();
 
  128    loader_b.load_unsafe();
 
  130    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  140  threadgroup_barrier(mem_flags::mem_none);
 
  145    int offset_m = c_row + mma_op.sm + mma_op.tm;
 
  146    int offset_n = c_col + mma_op.sn + mma_op.tn;
 
  149    if (offset_n >= gemm_params->N)
 
  152    short diff = gemm_params->N - offset_n;
 
  155    for (
int i = 0; i < mma_t::TM; i++) {
 
  156      int cm = offset_m + i * mma_t::TM_stride;
 
  158      int n = cm / jump_params->adj_out_hw;
 
  159      int hw = cm % jump_params->adj_out_hw;
 
  161          (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
 
  163          (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
 
  165      if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
 
  166        int offset_cm = n * params->out_strides[0] +
 
  167            oh * params->out_strides[1] + ow * params->out_strides[2];
 
  170        for (
int j = 0; j < mma_t::TN; j++) {
 
  172          thread 
const auto& accum =
 
  173              mma_op.results[i * mma_t::TN + j].thread_elements();
 
  174          int offset = offset_cm + (j * mma_t::TN_stride);
 
  177          if (j * mma_t::TN_stride < diff) {
 
  178            C[offset] = Epilogue::apply(accum[0]);
 
  181          if (j * mma_t::TN_stride + 1 < diff) {
 
  182            C[offset + 1] = Epilogue::apply(accum[1]);
 
 
void implicit_gemm_conv_2d_general(const device T *A, const device T *B, device T *C, const constant MLXConvParams< 2 > *params, const constant ImplicitGemmConv2DParams *gemm_params, const constant Conv2DGeneralJumpParams *jump_params, const constant Conv2DGeneralBaseInfo *base_h, const constant Conv2DGeneralBaseInfo *base_w, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid)
Definition steel_conv_general.h:15