81 threadgroup T* As [[threadgroup(0)]],
82 threadgroup T* Bs [[threadgroup(1)]],
83 const int gemm_k_iterations,
87 thread
const short& tgp_bm,
88 thread
const short& tgp_bn,
89 thread
const short& lbk,
94 short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
96 short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
98 for (
int k = 0; k < gemm_k_iterations; k++) {
99 threadgroup_barrier(mem_flags::mem_threadgroup);
102 loader_a.load_unsafe();
104 loader_a.load_safe(tile_dims_A);
108 loader_b.load_unsafe();
110 loader_b.load_safe(tile_dims_B);
113 threadgroup_barrier(mem_flags::mem_threadgroup);
124 threadgroup_barrier(mem_flags::mem_threadgroup);
126 short2 tile_dims_A_last =
127 transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
128 short2 tile_dims_B_last =
129 transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
131 loader_a.load_safe(tile_dims_A_last);
132 loader_b.load_safe(tile_dims_B_last);
134 threadgroup_barrier(mem_flags::mem_threadgroup);
141 static METAL_FUNC
void run(
142 const device T* A [[buffer(0)]],
143 const device T* B [[buffer(1)]],
144 device U* D [[buffer(2)]],
145 const constant
GEMMParams* params [[buffer(3)]],
146 threadgroup T* As [[threadgroup(0)]],
147 threadgroup T* Bs [[threadgroup(1)]],
148 uint simd_lane_id [[thread_index_in_simdgroup]],
149 uint simd_group_id [[simdgroup_index_in_threadgroup]],
150 uint3 tid [[threadgroup_position_in_grid]],
151 uint3 lid [[thread_position_in_threadgroup]]) {
155 const int tid_y = ((tid.y) << params->swizzle_log) +
156 ((tid.x) & ((1 << params->swizzle_log) - 1));
157 const int tid_x = (tid.x) >> params->swizzle_log;
159 if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
163 threadgroup_barrier(mem_flags::mem_none);
166 const int c_row = tid_y * BM;
167 const int c_col = tid_x * BN;
168 const size_t c_row_long = size_t(c_row);
169 const size_t c_col_long = size_t(c_col);
171 A += transpose_a ? c_row_long : c_row_long * params->lda;
172 B += transpose_b ? c_col_long * params->ldb : c_col_long;
173 D += c_row_long * params->ldd + c_col_long;
176 thread
loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
177 thread
loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
180 thread
mma_t mma_op(simd_group_id, simd_lane_id);
182 int gemm_k_iterations = params->gemm_k_iterations_aligned;
187 for (
int k = 0; k < gemm_k_iterations; k++) {
188 threadgroup_barrier(mem_flags::mem_threadgroup);
190 loader_a.load_unsafe();
191 loader_b.load_unsafe();
193 threadgroup_barrier(mem_flags::mem_threadgroup);
203 threadgroup_barrier(mem_flags::mem_none);
207 int lbk = params->K - params->gemm_k_iterations_aligned * BK;
208 short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
209 short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
211 loader_a.load_safe(tile_dims_A);
212 loader_b.load_safe(tile_dims_B);
214 threadgroup_barrier(mem_flags::mem_threadgroup);
220 mma_op.store_result(D, params->ldd);
227 short tgp_bm =
min(BM, params->M - c_row);
228 short tgp_bn =
min(BN, params->N - c_col);
229 short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
231 if (tgp_bm == BM && tgp_bn == BN) {
243 mma_op.store_result(D, params->ldd);
246 }
else if (tgp_bn == BN) {
258 mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
261 }
else if (tgp_bm == BM) {
273 mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
288 mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));