// Split-K GEMM kernel: each threadgroup accumulates one BM x BN output tile
// over one K partition and writes the partial result into C.
//
// NOTE(review): this is a listing with original-file line numbers embedded at
// the start of each line, and several original lines are not shown (template
// header, the `params` argument, gemm_loop argument lists, closing braces).
// The comments below describe only what the visible lines establish.
//
// Arguments visible here:
//   A, B           - read-only device inputs (buffers 0 and 1).
//   C              - writable device output (buffer 2).
//   simd_lane_id   - thread index within its simdgroup.
//   simd_group_id  - simdgroup index within the threadgroup.
//   tid            - threadgroup position in the grid (x, y, z).
//   lid            - thread position within the threadgroup.
// `params` is dereferenced throughout but its declaration falls on an elided
// line of this listing.
21[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]]
void gemm_splitk(
22 const device T* A [[buffer(0)]],
23 const device T* B [[buffer(1)]],
24 device U* C [[buffer(2)]],
26 uint simd_lane_id [[thread_index_in_simdgroup]],
27 uint simd_group_id [[simdgroup_index_in_threadgroup]],
28 uint3 tid [[threadgroup_position_in_grid]],
29 uint3 lid [[thread_position_in_threadgroup]]) {
// Loader and MMA helper types are taken from gemm_kernel (its definition is
// not visible in this listing).
44 using loader_a_t =
typename gemm_kernel::loader_a_t;
45 using loader_b_t =
typename gemm_kernel::loader_b_t;
46 using mma_t =
typename gemm_kernel::mma_t;
// Threadgroup-memory staging arrays for A and B, sized by gemm_kernel.
48 threadgroup T As[gemm_kernel::tgp_mem_size_a];
49 threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Signed copies of the threadgroup grid position.
51 const int tid_x = tid.x;
52 const int tid_y = tid.y;
53 const int tid_z = tid.z;
// Guard for threadgroups whose x/y position is outside the tiles_n x tiles_m
// range. The guarded statement is elided here - presumably an early return;
// confirm against the full file.
55 if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
// Row/column origin of this threadgroup's output tile, and the first K index
// of the partition selected by tid_z.
60 const int c_row = tid_y * BM;
61 const int c_col = tid_x * BN;
62 const int k_start = params->split_k_partition_size * tid_z;
// Widen to size_t before the offsets are multiplied by leading dimensions.
64 const size_t c_row_long = size_t(c_row);
65 const size_t c_col_long = size_t(c_col);
66 const size_t k_start_long = size_t(k_start);
// Advance A and B to the start of this tile / K partition; which index is
// scaled by the leading dimension depends on transpose_a / transpose_b.
68 A += transpose_a ? (c_row_long + k_start_long * params->lda)
69 : (k_start_long + c_row_long * params->lda);
70 B += transpose_b ? (k_start_long + c_col_long * params->ldb)
71 : (c_col_long + k_start_long * params->ldb);
// Advance C by tid_z * split_k_partition_stride, then to the tile origin
// (c_row, c_col) using ldc.
72 C += (size_t(params->split_k_partition_stride) * tid_z) +
73 (c_row_long * params->ldc + c_col_long);
// Per-thread loaders bound to the device sources and the threadgroup arrays.
76 thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
77 thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Per-thread multiply-accumulate helper.
80 thread mma_t mma_op(simd_group_id, simd_lane_id);
82 int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Extent of this tile clipped to the M and N bounds, and the remainder of K
// after division by BK.
84 short tgp_bm = min(BM, params->M - c_row);
85 short tgp_bn = min(BN, params->N - c_col);
86 short leftover_bk = params->K % BK;
// Four gemm_loop call sites selected by whether the tile is full in both
// dimensions, full only in N, full only in M, or (last call) neither.
// The argument lists and the final `else` are elided in this listing.
88 if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
89 gemm_kernel::gemm_loop(
100 }
else if (tgp_bn == BN) {
101 gemm_kernel::gemm_loop(
112 }
else if (tgp_bm == BM) {
113 gemm_kernel::gemm_loop(
125 gemm_kernel::gemm_loop(
// Threadgroup-memory barrier placed before the last-partition handling below.
138 threadgroup_barrier(mem_flags::mem_threadgroup);
// Only the last K partition (tid_z + 1 == split_k_partitions) takes this path:
// it counts the whole BK steps of K lying beyond its own partition range and
// calls gemm_loop again when K is not aligned or such steps exist.
140 if ((tid_z + 1) == (params->split_k_partitions)) {
141 int gemm_k_iter_remaining =
142 (params->K - (k_start + params->split_k_partition_size)) / BK;
143 if (!K_aligned || gemm_k_iter_remaining > 0)
144 gemm_kernel::gemm_loop(
147 gemm_k_iter_remaining,
// Write the accumulated tile to C: plain store for full tiles, otherwise the
// "safe" store that is given the clipped extent as (tgp_bn, tgp_bm).
157 if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
158 mma_op.store_result(C, params->ldc);
160 mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
// Parameter list and body fragment of the split-K accumulation kernel
// (gemm_splitk_accum per the file's index; the template/kernel/name lines are
// elided in this listing). Each thread sums k_partitions values of C_split and
// writes one element of D.
//
//   C_split          - read-only partial results (buffer 0).
//   D                - output (buffer 1).
//   k_partitions     - number of values summed per output element (buffer 2).
//   partition_stride - element distance between consecutive summed values
//                      (buffer 3).
//   ldd              - row stride applied to gid.y for both D and C_split
//                      (buffer 4).
//   gid              - thread position in the grid.
173 const device AccT* C_split [[buffer(0)]],
174 device OutT* D [[buffer(1)]],
175 const constant
int& k_partitions [[buffer(2)]],
176 const constant
int& partition_stride [[buffer(3)]],
177 const constant
int& ldd [[buffer(4)]],
178 uint2 gid [[thread_position_in_grid]]) {
// Move both pointers to this thread's element: column gid.x, row gid.y.
180 D += gid.x + gid.y * size_t(ldd);
181 C_split += gid.x + gid.y * size_t(ldd);
// Sum the element at successive partition_stride offsets. The declarations of
// `out` and `offset` are on elided lines - presumably both start at zero;
// confirm against the full file.
186 for (
int i = 0; i < k_partitions; i++) {
187 out += C_split[offset];
188 offset += partition_stride;
// Pass the sum through Epilogue::apply and store it.
192 D[0] = Epilogue::apply(out);
// Parameter list and body fragment of the split-K accumulation kernel that
// also combines an existing matrix C (gemm_splitk_accum_axpby per the file's
// index; the template/kernel/name lines are elided in this listing). Each
// thread sums k_partitions values of C_split, then stores op.apply(sum, *C)
// where op is an Epilogue built from alpha and beta.
//
//   C_split          - read-only partial results (buffer 0).
//   D                - output (buffer 1).
//   k_partitions     - number of values summed per output element (buffer 2).
//   partition_stride - element distance between consecutive summed values
//                      (buffer 3).
//   ldd              - row stride applied to gid.y for D and C_split (buffer 4).
//   C                - read-only second operand of the epilogue (buffer 5).
//   ldc, fdc         - strides applied to gid.y and gid.x respectively when
//                      addressing C (buffers 6 and 7).
//   alpha, beta      - scalars handed to the Epilogue constructor (buffers 8
//                      and 9); how they are applied is defined by Epilogue,
//                      which is not visible here.
//   gid              - thread position in the grid.
200 const device AccT* C_split [[buffer(0)]],
201 device OutT* D [[buffer(1)]],
202 const constant
int& k_partitions [[buffer(2)]],
203 const constant
int& partition_stride [[buffer(3)]],
204 const constant
int& ldd [[buffer(4)]],
205 const device OutT* C [[buffer(5)]],
206 const constant
int& ldc [[buffer(6)]],
207 const constant
int& fdc [[buffer(7)]],
208 const constant
float& alpha [[buffer(8)]],
209 const constant
float& beta [[buffer(9)]],
210 uint2 gid [[thread_position_in_grid]]) {
// C uses its own strides in both dimensions; D and C_split share ldd with a
// unit column stride.
212 C += gid.x * size_t(fdc) + gid.y * size_t(ldc);
213 D += gid.x + gid.y * size_t(ldd);
214 C_split += gid.x + gid.y * size_t(ldd);
// Sum the element at successive partition_stride offsets. The declarations of
// `out` and `offset` are on elided lines - presumably both start at zero;
// confirm against the full file.
219 for (
int i = 0; i < k_partitions; i++) {
220 out += C_split[offset];
221 offset += partition_stride;
// Build the epilogue from alpha/beta and combine the sum with the C element.
225 Epilogue
op(alpha, beta);
226 D[0] =
op.apply(out, *C);
void gemm_splitk(const device T *A, const device T *B, device U *C, const constant GEMMSpiltKParams *params, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
Definition steel_gemm_splitk.h:21
void gemm_splitk_accum(const device AccT *C_split, device OutT *D, const constant int &k_partitions, const constant int &partition_stride, const constant int &ldd, uint2 gid)
Definition steel_gemm_splitk.h:172
void gemm_splitk_accum_axpby(const device AccT *C_split, device OutT *D, const constant int &k_partitions, const constant int &partition_stride, const constant int &ldd, const device OutT *C, const constant int &ldc, const constant int &fdc, const constant float &alpha, const constant float &beta, uint2 gid)
Definition steel_gemm_splitk.h:199