16 const device T* A [[buffer(0)]],
17 const device T* B [[buffer(1)]],
18 device T* C [[buffer(2)]],
20 const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
21 const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
22 const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
23 const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
24 uint3 tid [[threadgroup_position_in_grid]],
25 uint3 lid [[thread_position_in_threadgroup]],
26 uint simd_gid [[simdgroup_index_in_threadgroup]],
27 uint simd_lid [[thread_index_in_simdgroup]]) {
30 constexpr bool transpose_a =
false;
31 constexpr bool transpose_b =
true;
32 constexpr short tgp_padding_a = 16 /
sizeof(T);
33 constexpr short tgp_padding_b = 16 /
sizeof(T);
35 constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
36 constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
37 constexpr short shape_a_rows = (transpose_a ? BK : BM);
38 constexpr short shape_b_rows = (transpose_b ? BN : BK);
39 constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
40 constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
42 constexpr short tgp_size = WM * WN * 32;
46 Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;
50 Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;
52 using mma_t = BlockMMA<
65 threadgroup T As[tgp_mem_size_a];
66 threadgroup T Bs[tgp_mem_size_b];
68 const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
69 ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
70 const int tid_x = (tid.x) >> gemm_params->swizzle_log;
72 if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
76 const int tid_z = tid.z;
78 const int base_oh = tid_z / jump_params->f_out_jump_w;
79 const int base_ow = tid_z % jump_params->f_out_jump_w;
81 const int base_wh = base_h[base_oh].weight_base;
82 const int base_ww = base_w[base_ow].weight_base;
84 const int base_wh_size = base_h[base_oh].weight_size;
85 const int base_ww_size = base_w[base_ow].weight_size;
87 const int c_row = tid_y * BM;
88 const int c_col = tid_x * BN;
89 const int K = gemm_params->K;
93 const int4 offsets_a(0, c_row, base_oh, base_ow);
94 const int2 offsets_b(0, c_col);
119 mma_t mma_op(simd_gid, simd_lid);
121 int gemm_k_iterations =
122 base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
124 for (
int k = 0; k < gemm_k_iterations; k++) {
125 threadgroup_barrier(mem_flags::mem_threadgroup);
127 loader_a.load_unsafe();
128 loader_b.load_unsafe();
130 threadgroup_barrier(mem_flags::mem_threadgroup);
140 threadgroup_barrier(mem_flags::mem_none);
145 int offset_m = c_row + mma_op.sm;
146 int offset_n = c_col + mma_op.sn;
149 if (offset_n >= gemm_params->N)
152 short diff = gemm_params->N - offset_n;
155 for (
int i = 0; i < mma_t::TM; i++) {
156 int cm = offset_m + i * mma_t::TM_stride;
158 int n = cm / jump_params->adj_out_hw;
159 int hw = cm % jump_params->adj_out_hw;
161 (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
163 (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
165 if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
166 int offset_cm = n * params->out_strides[0] +
167 oh * params->out_strides[1] + ow * params->out_strides[2];
170 for (
int j = 0; j < mma_t::TN; j++) {
172 thread
const auto& accum = mma_op.Ctile.frag_at(i, j);
173 int offset = offset_cm + (j * mma_t::TN_stride);
175 constexpr short kelems =
decltype(mma_op.Ctile)::kElemsPerFrag;
179 for (
short k = 0; k < kelems; k++) {
180 if ((j * mma_t::TN_stride + k) < diff) {
181 C[offset + k] = Epilogue::apply(accum[k]);
void implicit_gemm_conv_2d_general(const device T *A, const device T *B, device T *C, const constant MLXConvParams< 2 > *params, const constant ImplicitGemmConv2DParams *gemm_params, const constant Conv2DGeneralJumpParams *jump_params, const constant Conv2DGeneralBaseInfo *base_h, const constant Conv2DGeneralBaseInfo *base_w, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid)
Definition steel_conv_general.h:15