18 const device T* A [[buffer(0)]],
19 const device T* B [[buffer(1)]],
20 device T* C [[buffer(2)]],
22 const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
23 uint3 tid [[threadgroup_position_in_grid]],
24 uint3 lid [[thread_position_in_threadgroup]],
25 uint simd_gid [[simdgroup_index_in_threadgroup]],
26 uint simd_lid [[thread_index_in_simdgroup]]) {
31 constexpr bool transpose_a =
false;
32 constexpr bool transpose_b =
true;
33 constexpr short tgp_padding_a = 16 /
sizeof(T);
34 constexpr short tgp_padding_b = 16 /
sizeof(T);
36 constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
37 constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
38 constexpr short shape_a_rows = (transpose_a ? BK : BM);
39 constexpr short shape_b_rows = (transpose_b ? BN : BK);
40 constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
41 constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
43 constexpr short tgp_size = WM * WN * 32;
47 using loader_a_t =
typename metal::conditional_t<
49 N_CHANNELS != 0 && N_CHANNELS <= 4,
62 typename metal::conditional_t<
85 using loader_b_t =
typename metal::conditional_t<
87 N_CHANNELS != 0 && N_CHANNELS <= 4,
115 threadgroup T As[tgp_mem_size_a];
116 threadgroup T Bs[tgp_mem_size_b];
118 const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
119 ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
120 const int tid_x = (tid.x) >> gemm_params->swizzle_log;
122 if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
126 const int c_row = tid_y * BM;
127 const int c_col = tid_x * BN;
128 const int K = gemm_params->K;
129 const int N = gemm_params->N;
130 const int C_per_group = params->C / params->groups;
133 A += tid.z * C_per_group;
138 C += c_row * (N * params->groups) + c_col;
140 const int2 offsets_a(0, c_row);
141 const int2 offsets_b(0, c_col);
145 A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
147 B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
150 mma_t mma_op(simd_gid, simd_lid);
152 int gemm_k_iterations = gemm_params->gemm_k_iterations;
153 for (
int k = 0; k < gemm_k_iterations; k++) {
154 threadgroup_barrier(mem_flags::mem_threadgroup);
156 loader_a.load_unsafe();
157 loader_b.load_unsafe();
159 threadgroup_barrier(mem_flags::mem_threadgroup);
169 threadgroup_barrier(mem_flags::mem_none);
172 short tgp_bm =
min(BM, gemm_params->M - c_row);
173 short tgp_bn =
min(BN, gemm_params->N - c_col);
174 const int ldc = N * params->groups;
175 mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));
void implicit_gemm_conv_2d(const device T *A, const device T *B, device T *C, const constant MLXConvParams< 2 > *params, const constant ImplicitGemmConv2DParams *gemm_params, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid)
Definition steel_conv.h:17