[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]]
void gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device T* C [[buffer(2), function_constant(use_out_source)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params
        [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
    const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
    const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
    const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
    const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
    const constant packed_int3& operand_batch_ndim
        [[buffer(15), function_constant(do_gather)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
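  // The kernel is specialized at pipeline-creation time through Metal
  // function constants: use_out_source (fused C operand for addmm), do_gather
  // and gather_bias (indexed batching), has_batch, do_axpby, and the tile
  // alignment flags align_M / align_N / align_K. Branches guarded by these
  // constants are resolved when the pipeline is compiled, so each variant
  // pays nothing for the paths it does not take.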
  (void)lid; // unused

  using gemm_kernel = GEMMKernel<
      T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b,
      /* MN_aligned = */ true, /* K_aligned = */ true, AccumType>;
  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  // Find the tile; the launch grid is swizzled so that neighboring
  // threadgroups work on nearby tiles
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  // Exit early if the tile is out of bounds
  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }
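  // Three batching modes follow: gathered batching (do_gather), general
  // broadcast batching over a shape/strides pair (has_batch), and the common
  // case of one dense batch stride per operand.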
  // Adjust for batch: gathered indexing first
  if (do_gather) {
    // Read the operand indices for this batch element
    uint32_t indx_A, indx_B, indx_C;

    if (has_batch) {
      const constant size_t* indx_A_bstrides = batch_strides;
      const constant size_t* indx_B_bstrides =
          batch_strides + params->batch_ndim;

      ulong2 indx_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, indx_A_bstrides, indx_B_bstrides,
          params->batch_ndim);
      indx_A = lhs_indices[indx_offsets.x];
      indx_B = rhs_indices[indx_offsets.y];

      if (use_out_source) {
        const constant size_t* indx_C_bstrides =
            indx_B_bstrides + params->batch_ndim;
        auto indx_offset_C = elem_to_loc(
            tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
        indx_C = C_indices[indx_offset_C];
      }
    } else {
      indx_A = lhs_indices[params->batch_stride_a * tid.z];
      indx_B = rhs_indices[params->batch_stride_b * tid.z];

      if (use_out_source) {
        indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
      }
    }
    // Translate the gathered indices into element offsets for each operand
    int batch_ndim_A = operand_batch_ndim.x;
    const constant int* batch_shape_A = operand_shape;
    const constant size_t* batch_strides_A = operand_strides;
    A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);

    int batch_ndim_B = operand_batch_ndim.y;
    const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
    const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
    B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);

    if (use_out_source) {
      int batch_ndim_C = operand_batch_ndim.z;
      const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
      const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
      C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
    }
  }
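  // elem_to_loc maps a linear index to a strided memory offset given a
  // shape/strides pair; the _broadcast variant does the same for two operands
  // sharing one broadcast batch shape in a single pass.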
  // Handle regular (non-gather) batching
  else {
    if (has_batch) {
      const constant size_t* A_bstrides = batch_strides;
      const constant size_t* B_bstrides = batch_strides + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
      A += batch_offsets.x;
      B += batch_offsets.y;

      if (use_out_source) {
        const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
        C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
      }
    } else {
      A += params->batch_stride_a * tid.z;
      B += params->batch_stride_b * tid.z;

      if (use_out_source) {
        C += addmm_params->batch_stride_c * tid.z;
      }
    }
  }

  D += params->batch_stride_d * tid.z;
  // Prepare threadgroup memory for the A and B tiles
  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  threadgroup_barrier(mem_flags::mem_none);

  // Offset each operand to this threadgroup's tile
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  if (use_out_source) {
    C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
  }
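  // lda / ldb / ldc / ldd are the leading dimensions of A, B, C and D. fdc
  // is the step between adjacent elements within a row of C; together with
  // ldc it lets a broadcast C (e.g. a bias row with ldc == 0) be fused into
  // the epilogue without materializing a full matrix.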
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));

  int gemm_k_iterations = params->gemm_k_iterations_aligned;
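  // When K is not a multiple of BK (!align_K), the single ragged K iteration
  // runs first with bounds-checked loads; the main loop can then use the
  // cheaper unchecked loads for every remaining iteration.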
  // Do the unaligned K iteration first
  if (!align_K) {
    const int k_last = params->gemm_k_iterations_aligned * BK;
    const int k_remain = params->K - k_last;
    const size_t k_jump_a =
        transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
    const size_t k_jump_b =
        transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

    // Move the loader sources to the ragged tail of K
    loader_a.src += k_jump_a;
    loader_b.src += k_jump_b;

    // Load the partial tile with bounds checks
    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

    loader_a.load_safe(tile_dims_A);
    loader_b.load_safe(tile_dims_B);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate the partial tile
    mma_op.mma(As, Bs);

    // Reset the loader sources back to the start of K
    loader_a.src -= k_jump_a;
    loader_b.src -= k_jump_b;
  }

  const TransformAdd<AccumType, AccumType> epilogue_op_add(
      addmm_params->alpha, addmm_params->beta);
  const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
      addmm_params->alpha, addmm_params->beta);
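  // Epilogue selection: TransformAdd computes D = AB + C, while
  // TransformAxpby computes D = alpha * AB + beta * C; the do_axpby function
  // constant picks between them below.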
  // MNK-aligned loop: no bounds checks anywhere
  if (align_M && align_N) {
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load A and B tiles into threadgroup memory
      loader_a.load_unsafe();
      loader_b.load_unsafe();
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Multiply-accumulate, then advance the loaders along K
      mma_op.mma(As, Bs);
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Apply the fused C epilogue, then store to device memory
    if (use_out_source) {
      if (do_axpby) {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
      } else {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
      }
    }

    return mma_op.store_result(D, params->ldd);
  }
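  // gemm_kernel::gemm_loop is the shared K loop, specialized by a
  // LoopAlignment<M_aligned, N_aligned, K_aligned> tag so that only the
  // ragged edges of a tile pay for bounds-checked loads and stores.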
  // MN-unaligned loop: pick a bounds-checking strategy per edge case
  else {
    const int leftover_bk = 0;

    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
      // Fully sized tile: no bounds checks needed
      gemm_kernel::gemm_loop(
          As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
          tgp_bm, tgp_bn, leftover_bk, LoopAlignment<true, true, true>{});

      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
        }
      }

      return mma_op.store_result(D, params->ldd);

    } else if (align_N || tgp_bn == BN) {
      // Short in M only: bounds-check rows
      gemm_kernel::gemm_loop(
          As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
          tgp_bm, tgp_bn, leftover_bk, LoopAlignment<false, true, true>{});

      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_add);
        }
      }

      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));

    } else if (align_M || tgp_bm == BM) {
      // Short in N only: bounds-check columns
      gemm_kernel::gemm_loop(
          As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
          tgp_bm, tgp_bn, leftover_bk, LoopAlignment<true, false, true>{});

      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_add);
        }
      }

      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));

    } else {
      // Short in both M and N: bounds-check everything
      gemm_kernel::gemm_loop(
          As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
          tgp_bm, tgp_bn, leftover_bk, LoopAlignment<false, false, true>{});

      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_add);
        }
      }

      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}