80 threadgroup T* As [[threadgroup(0)]],
81 threadgroup T* Bs [[threadgroup(1)]],
82 const int gemm_k_iterations,
86 thread
const short& tgp_bm,
87 thread
const short& tgp_bn,
88 thread
const short& lbk,
93 short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
95 short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
97 for (
int k = 0; k < gemm_k_iterations; k++) {
98 threadgroup_barrier(mem_flags::mem_threadgroup);
101 loader_a.load_unsafe();
103 loader_a.load_safe(tile_dims_A);
107 loader_b.load_unsafe();
109 loader_b.load_safe(tile_dims_B);
112 threadgroup_barrier(mem_flags::mem_threadgroup);
123 threadgroup_barrier(mem_flags::mem_threadgroup);
125 short2 tile_dims_A_last =
126 transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
127 short2 tile_dims_B_last =
128 transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
130 loader_a.load_safe(tile_dims_A_last);
131 loader_b.load_safe(tile_dims_B_last);
133 threadgroup_barrier(mem_flags::mem_threadgroup);
140 static METAL_FUNC
void run(
141 const device T* A [[buffer(0)]],
142 const device T* B [[buffer(1)]],
143 device U* D [[buffer(2)]],
144 const constant
GEMMParams* params [[buffer(3)]],
145 threadgroup T* As [[threadgroup(0)]],
146 threadgroup T* Bs [[threadgroup(1)]],
147 uint simd_lane_id [[thread_index_in_simdgroup]],
148 uint simd_group_id [[simdgroup_index_in_threadgroup]],
149 uint3 tid [[threadgroup_position_in_grid]],
150 uint3 lid [[thread_position_in_threadgroup]]) {
154 const int tid_y = ((tid.y) << params->swizzle_log) +
155 ((tid.x) & ((1 << params->swizzle_log) - 1));
156 const int tid_x = (tid.x) >> params->swizzle_log;
158 if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
162 threadgroup_barrier(mem_flags::mem_none);
165 const int c_row = tid_y * BM;
166 const int c_col = tid_x * BN;
167 const size_t c_row_long = size_t(c_row);
168 const size_t c_col_long = size_t(c_col);
170 A += transpose_a ? c_row_long : c_row_long * params->lda;
171 B += transpose_b ? c_col_long * params->ldb : c_col_long;
172 D += c_row_long * params->ldd + c_col_long;
175 thread
loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
176 thread
loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
179 thread
mma_t mma_op(simd_group_id, simd_lane_id);
181 int gemm_k_iterations = params->gemm_k_iterations_aligned;
186 for (
int k = 0; k < gemm_k_iterations; k++) {
187 threadgroup_barrier(mem_flags::mem_threadgroup);
189 loader_a.load_unsafe();
190 loader_b.load_unsafe();
192 threadgroup_barrier(mem_flags::mem_threadgroup);
202 threadgroup_barrier(mem_flags::mem_none);
206 int lbk = params->K - params->gemm_k_iterations_aligned * BK;
207 short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
208 short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
210 loader_a.load_safe(tile_dims_A);
211 loader_b.load_safe(tile_dims_B);
213 threadgroup_barrier(mem_flags::mem_threadgroup);
219 mma_op.store_result(D, params->ldd);
226 short tgp_bm =
min(BM, params->M - c_row);
227 short tgp_bn =
min(BN, params->N - c_col);
228 short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
230 if (tgp_bm == BM && tgp_bn == BN) {
242 mma_op.store_result(D, params->ldd);
245 }
else if (tgp_bn == BN) {
257 mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
260 }
else if (tgp_bm == BM) {
272 mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
287 mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));