[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]]
void gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device T* C [[buffer(2), function_constant(use_out_source)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant int64_t* batch_strides [[buffer(7)]],
    const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
    const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
    const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
    const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
    const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
    const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
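  // `gemm_kernel` is the GEMMKernel<...> instantiation declared just above these
  // typedefs; has_batch, do_axpby, do_gather, use_out_source, gather_bias, and the
  // align_M/N/K flags are compile-time function constants declared earlier in the file.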
  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;
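  // Map the threadgroup id to an output tile; the swizzle keeps neighboring
  // threadgroups on nearby tiles to improve cache reuse.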
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;
  // Exit early if this threadgroup falls outside the tile grid
  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }
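  // Gather path: each batch element selects its A / B (and optionally C) slices
  // through explicit index arrays.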
  if (do_gather) {
    // Read the operand indices for this batch element
    uint32_t indx_A, indx_B, indx_C;

    if (has_batch) {
      const constant auto* indx_A_bstrides = batch_strides;
      const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim;

      auto indx_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, indx_A_bstrides, indx_B_bstrides, params->batch_ndim);
      indx_A = lhs_indices[indx_offsets.x];
      indx_B = rhs_indices[indx_offsets.y];
      if (use_out_source) {
        const constant auto* indx_C_bstrides =
            indx_B_bstrides + params->batch_ndim;
        auto indx_offset_C = elem_to_loc(
            tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
        indx_C = C_indices[indx_offset_C];
      }
    } else {
      indx_A = lhs_indices[params->batch_stride_a * tid.z];
      indx_B = rhs_indices[params->batch_stride_b * tid.z];

      if (use_out_source) {
        indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
      }
    }
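    // Translate the gathered indices into element offsets into A, B (and C)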
    int batch_ndim_A = operand_batch_ndim.x;
    const constant int* batch_shape_A = operand_shape;
    const constant auto* batch_strides_A = operand_strides;
    A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
    int batch_ndim_B = operand_batch_ndim.y;
    const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
    const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A;
    B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
    if (use_out_source) {
      int batch_ndim_C = operand_batch_ndim.z;
      const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
      const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B;
      C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
    }
  }
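  // Regular (non-gather) path: apply the usual batch offsets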
  else {
    if (has_batch) {
      const constant auto* A_bstrides = batch_strides;
      const constant auto* B_bstrides = batch_strides + params->batch_ndim;

      auto batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

      A += batch_offsets.x;
      B += batch_offsets.y;
      if (use_out_source) {
        const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
        C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
      }
    } else {
      A += params->batch_stride_a * tid.z;
      B += params->batch_stride_b * tid.z;

      if (use_out_source) {
        C += addmm_params->batch_stride_c * tid.z;
      }
    }
  }
  D += params->batch_stride_d * tid.z;
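  // Threadgroup memory for the staged A and B tiles consumed by the simdgroup MMA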
  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  threadgroup_barrier(mem_flags::mem_none);
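  // Locate this threadgroup's output block and advance A, B, C, and D to it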
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  if (use_out_source) {
    C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
  }
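  // Per-simdgroup MMA worker and cooperative device-to-threadgroup tile loaders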
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
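  // Clamp the tile extents at the right / bottom edges when M or N is not tile-aligned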
  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));

  int gemm_k_iterations = params->gemm_k_iterations_aligned;
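  // When K is not a multiple of BK, process the K remainder first with
  // bounds-checked loads, then fall through to the aligned main loop.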
  if (!align_K) {
    const int k_last = params->gemm_k_iterations_aligned * BK;
    const int k_remain = params->K - k_last;
    const size_t k_jump_a =
        transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
    const size_t k_jump_b =
        transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

    // Move the loader sources ahead to the start of the K remainder
    loader_a.src += k_jump_a;
    loader_b.src += k_jump_b;

    // Load the (possibly partial) remainder tiles with bounds checks
    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

    loader_a.load_safe(tile_dims_A);
    loader_b.load_safe(tile_dims_B);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply-accumulate the remainder tiles
    mma_op.mma(As, Bs);

    // Reset the loader sources back to the start of K
    loader_a.src -= k_jump_a;
    loader_b.src -= k_jump_b;
  }
  // Epilogue operators (TransformAdd / TransformAxpby, parameterized on the
  // kernel's AccumType): plain accumulation of C, or D = alpha * (A @ B) + beta * C
  TransformAdd<AccumType, AccumType> epilogue_op_add(
      addmm_params->alpha, addmm_params->beta);
  TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
      addmm_params->alpha, addmm_params->beta);
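  // Fast path: both M and N are tile-aligned, so loads and stores need no bounds checks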
  if (align_M && align_N) {
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Load the next K tiles of A and B into threadgroup memory
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply-accumulate and advance the loaders
      mma_op.mma(As, Bs);
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);
    // Apply the epilogue when an out source C is present
    if (use_out_source) {
      if (do_axpby) {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
      } else {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
      }
    }

    // Store results to device memory
    return mma_op.store_result(D, params->ldd);
  }
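  // Edge path: M and/or N is not tile-aligned, so use the guarded loop and safe stores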
  else {
    const int leftover_bk = 0;

    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
      // This threadgroup still owns a full BM x BN tile
      gemm_kernel::gemm_loop(
          As, Bs, gemm_k_iterations,
          loader_a, loader_b, mma_op,
          tgp_bm, tgp_bn, leftover_bk,
          LoopAlignment<true, true, true>{}); // M, N, K alignment flags
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
        }
      }

      return mma_op.store_result(D, params->ldd);
    }
    else if (align_N || tgp_bn == BN) {
      // Partial tile along M only
      gemm_kernel::gemm_loop(
          As, Bs, gemm_k_iterations,
          loader_a, loader_b, mma_op,
          tgp_bm, tgp_bn, leftover_bk,
          LoopAlignment<false, true, true>{});
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_add);
        }
      }
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
    else if (align_M || tgp_bm == BM) {
      // Partial tile along N only
      gemm_kernel::gemm_loop(
          As, Bs, gemm_k_iterations,
          loader_a, loader_b, mma_op,
          tgp_bm, tgp_bn, leftover_bk,
          LoopAlignment<true, false, true>{});
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_add);
        }
      }
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
    else {
      // Partial tile along both M and N
      gemm_kernel::gemm_loop(
          As, Bs, gemm_k_iterations,
          loader_a, loader_b, mma_op,
          tgp_bm, tgp_bn, leftover_bk,
          LoopAlignment<false, false, true>{});
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C, addmm_params->ldc, addmm_params->fdc,
              short2(tgp_bn, tgp_bm), epilogue_op_add);
        }
      }
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}