    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const device out_mask_t* out_mask [[buffer(10)]],
    const device op_mask_t* lhs_mask [[buffer(11)]],
    const device op_mask_t* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
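  // lid is unused in the visible body; keep the compiler happy.
  (void)lid;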
  static_assert(
      BM == BN,
      "block_masked_gemm must have the same block M and block N size");
  static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0");
  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

  constexpr bool has_mul_operand_mask =
      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
  constexpr bool has_mul_output_mask =
      has_output_mask && !metal::is_same_v<out_mask_t, bool>;

  constexpr short k_mask_factor = short(BM / BK);
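  // A mask block spans BM (== BN) rows/columns while the K loop advances in
  // BK steps, so one operand-mask entry covers k_mask_factor K iterations.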
  // Swizzle the threadgroup grid position into tile coordinates.
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }
  // The mask batch strides are stored after the A and B batch strides.
  const constant size_t* mask_batch_strides =
      batch_strides + 2 * params->batch_ndim;
  if (params->batch_ndim > 1) {
    if (has_output_mask) {
      out_mask += elem_to_loc(
          tid.z, batch_shape, mask_batch_strides, params->batch_ndim);

      mask_batch_strides += params->batch_ndim;
    }

    if (has_operand_mask) {
      const constant size_t* mask_strides_lhs = mask_batch_strides;
      const constant size_t* mask_strides_rhs =
          mask_strides_lhs + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_lhs, mask_strides_rhs,
          params->batch_ndim);

      lhs_mask += batch_offsets.x;
      rhs_mask += batch_offsets.y;
    }
  } else {
    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += params->batch_ndim;
    }

    if (has_operand_mask) {
      lhs_mask += tid.z * mask_batch_strides[0];
      rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];
    }
  }
  // Adjust A, B, and D for the batch index.
  if (params->batch_ndim > 1) {
    const constant size_t* A_bstrides = batch_strides;
    const constant size_t* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;
  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;
  }

  D += params->batch_stride_d * tid.z;
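  // Locate this threadgroup's output tile and advance A, B, and D to it.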
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;
  const constant int* out_mask_strides = mask_strides;
  const constant int* lhs_mask_strides =
      mask_strides + (has_output_mask ? 2 : 0);
  const constant int* rhs_mask_strides =
      lhs_mask_strides + (has_operand_mask ? 2 : 0);
  const int out_mask_offset = !has_output_mask
      ? 0
      : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];
  int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];
  int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];
  const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];
  const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1];
  short k_factor_cnt = k_mask_factor;
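  // lhs_mask_offset / rhs_mask_offset walk the K dimension of the operand
  // masks, advancing one step every k_mask_factor BK iterations (k_factor_cnt).

  // Scale functors for the multiplicative masks, applied below through
  // apply_inplace_op / apply_epilogue (assumed ScaleOp declarations).
  ScaleOp<float> out_mask_op;
  ScaleOp<T> lhs_mask_op;
  ScaleOp<T> rhs_mask_op;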
  if (has_output_mask) {
    auto mask_out = out_mask[out_mask_offset];

    if (has_mul_output_mask) {
      out_mask_op.scale = float(mask_out);
    }

    // Write zeros and return early if the output block is masked off.
    if (!mask_out) {
      constexpr short tgp_size = WM * WN * 32;
      constexpr short vec_size = 4;

      constexpr short TN = BN / vec_size;
      constexpr short TM = tgp_size / TN;

      const short thread_idx = simd_group_id * 32 + simd_lane_id;
      const short bi = thread_idx / TN;
      const short bj = vec_size * (thread_idx % TN);

      D += bi * params->ldd + bj;
      short tgp_bm = min(BM, params->M - c_row);
      short tgp_bn = min(BN, params->N - c_col);
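      // Zero the D tile: full-tile fast path when the block is in bounds,
      // bounds-checked path otherwise.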
      if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
        for (short ti = 0; ti < BM; ti += TM) {
          for (short j = 0; j < vec_size; j++) {
            D[ti * params->ldd + j] = T(0.);
          }
        }
      } else {
        short jmax = tgp_bn - bj;
        jmax = jmax < vec_size ? jmax : vec_size;
        for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
          for (short j = 0; j < jmax; j++) {
            D[ti * params->ldd + j] = T(0.);
          }
        }
      }

      return;
    }
  }

  threadgroup_barrier(mem_flags::mem_none);
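  // Set up the simdgroup MMA tile and the threadgroup tile loaders for A / B.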
  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  thread typename gemm_kernel::loader_a_t loader_a(
      A, params->lda, As, simd_group_id, simd_lane_id);
  thread typename gemm_kernel::loader_b_t loader_b(
      B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Threadgroup tile bounds in M and N.
  const short tgp_bm =
      MN_aligned ? short(BM) : short(min(BM, params->M - c_row));
  const short tgp_bn =
      MN_aligned ? short(BN) : short(min(BN, params->N - c_col));

  int gemm_k_iterations = params->gemm_k_iterations_aligned;
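  // K is processed in BK-sized tiles; the leftover K fringe (when K is not a
  // multiple of BK) is handled first with bounds-checked loads, so the main
  // loop can use unchecked loads.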
  const int k_last = params->gemm_k_iterations_aligned * BK;
  const int mask_idx_last = k_last / BM;

  if (!has_operand_mask ||
      (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&
       bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {
    if (has_mul_operand_mask) {
      lhs_mask_op.scale =
          lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];
      rhs_mask_op.scale =
          rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];
    }

    // Jump the loader sources to the start of the K fringe.
    const int k_remain = params->K - k_last;
    const size_t k_jump_a =
        transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
    const size_t k_jump_b =
        transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

    loader_a.src += k_jump_a;
    loader_b.src += k_jump_b;

    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

    loader_a.load_safe(tile_dims_A);
    loader_b.load_safe(tile_dims_B);

    if (has_mul_operand_mask) {
      loader_a.apply_inplace_op(lhs_mask_op);
      loader_b.apply_inplace_op(rhs_mask_op);
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate the fringe tile.
    mma_op.mma(As, Bs);

    // Reset the loader sources back to the start of K.
    loader_a.src -= k_jump_a;
    loader_b.src -= k_jump_b;
  }
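  // Main loop over K tiles: the MN-aligned fast path stores the full tile,
  // the edge path below uses bounds-checked loads and stores.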
  if (MN_aligned) {
    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (bool(lhs_mask[lhs_mask_offset]) &&
           bool(rhs_mask[rhs_mask_offset]))) {
        if (has_mul_operand_mask) {
          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
        }

        loader_a.load_unsafe();
        loader_b.load_unsafe();

        if (has_mul_operand_mask) {
          loader_a.apply_inplace_op(lhs_mask_op);
          loader_b.apply_inplace_op(rhs_mask_op);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate.
        mma_op.mma(As, Bs);
      }

      // Next K tile; step the mask offsets every k_mask_factor iterations.
      loader_a.next();
      loader_b.next();

      k_factor_cnt--;
      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
    }

    if (has_mul_output_mask) {
      mma_op.apply_epilogue(out_mask_op);
    }

    mma_op.store_result(D, params->ldd);
  } else {
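    // Edge tiles: bounds-check loads on the unaligned dimension(s) and use a
    // safe store at the end.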
    const bool M_aligned = (tgp_bm == BM);
    const bool N_aligned = (tgp_bn == BN);

    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (!has_operand_mask ||
          (bool(lhs_mask[lhs_mask_offset]) &&
           bool(rhs_mask[rhs_mask_offset]))) {
        if (has_mul_operand_mask) {
          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
        }

        // Load elements into threadgroup memory, bounds-checked on the
        // unaligned dimension.
        if (M_aligned) {
          loader_a.load_unsafe();
        } else {
          loader_a.load_safe(tile_dims_A);
        }

        if (N_aligned) {
          loader_b.load_unsafe();
        } else {
          loader_b.load_safe(tile_dims_B);
        }

        if (has_mul_operand_mask) {
          loader_a.apply_inplace_op(lhs_mask_op);
          loader_b.apply_inplace_op(rhs_mask_op);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate.
        mma_op.mma(As, Bs);
      }

      // Next K tile; step the mask offsets every k_mask_factor iterations.
      loader_a.next();
      loader_b.next();

      k_factor_cnt--;
      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
    }

    if (has_mul_output_mask) {
      mma_op.apply_epilogue(out_mask_op);
    }

    if (M_aligned && N_aligned) {
      mma_op.store_result(D, params->ldd);
    } else {
      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}
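// Boolean-mask variant: blocks of A, B, and the output are gated by bool
// masks only (no multiplicative scale masks).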
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const device bool* out_mask [[buffer(10)]],
    const device bool* lhs_mask [[buffer(11)]],
    const device bool* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }
  if (params->batch_ndim > 1) {
    const constant size_t* mask_batch_strides =
        batch_strides + 2 * params->batch_ndim;
    out_mask +=
        elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);

    if (has_operand_mask) {
      const constant size_t* mask_strides_lhs =
          mask_batch_strides + params->batch_ndim;
      const constant size_t* mask_strides_rhs =
          mask_strides_lhs + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_lhs, mask_strides_rhs,
          params->batch_ndim);

      lhs_mask += batch_offsets.x;
      rhs_mask += batch_offsets.y;
    }
  } else {
    out_mask += tid.z * batch_strides[2 * params->batch_ndim];
    if (has_operand_mask) {
      lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
      rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
    }
  }
  // Adjust A, B, and D for the batch index.
  if (params->batch_ndim > 1) {
    const constant size_t* A_bstrides = batch_strides;
    const constant size_t* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;
  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;
  }

  D += params->batch_stride_d * tid.z;
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;
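  // Read this tile's output mask; if it is off, zero the D tile and return.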
  bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];

  if (!mask_out) {
    constexpr short tgp_size = WM * WN * 32;
    constexpr short vec_size = 4;

    constexpr short TN = BN / vec_size;
    constexpr short TM = tgp_size / TN;

    const short thread_idx = simd_group_id * 32 + simd_lane_id;
    const short bi = thread_idx / TN;
    const short bj = vec_size * (thread_idx % TN);

    D += bi * params->ldd + bj;
    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);
    if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
      for (short ti = 0; ti < BM; ti += TM) {
        for (short j = 0; j < vec_size; j++) {
          D[ti * params->ldd + j] = T(0.);
        }
      }
    } else {
      short jmax = tgp_bn - bj;
      jmax = jmax < vec_size ? jmax : vec_size;
      for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
        for (short j = 0; j < jmax; j++) {
          D[ti * params->ldd + j] = T(0.);
        }
      }
    }

    return;
  }

  threadgroup_barrier(mem_flags::mem_none);
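  // Prepare the MMA tile and threadgroup loaders, then run the K loop.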
  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  thread typename gemm_kernel::loader_a_t loader_a(
      A, params->lda, As, simd_group_id, simd_lane_id);
  thread typename gemm_kernel::loader_b_t loader_b(
      B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // MNK-aligned path: full tiles, unchecked loads; a K block is accumulated
  // only when both of its operand mask blocks are set.
  if (MN_aligned) {
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
           rhs_mask
               [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate.
        mma_op.mma(As, Bs);
      }

      // Advance to the next K tile.
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Loop tail: leftover K when K is not a multiple of BK.
    if (!has_operand_mask ||
        (lhs_mask
             [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
         rhs_mask
             [(params->K / BM) * mask_strides[5] +
              tid_x * mask_strides[4]])) {
      int lbk = params->K - params->gemm_k_iterations_aligned * BK;
      short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
      short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

      loader_a.load_safe(tile_dims_A);
      loader_b.load_safe(tile_dims_B);

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate.
      mma_op.mma(As, Bs);
    }

    // Store results to device memory.
    mma_op.store_result(D, params->ldd);
  } else {
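    // MN edge tiles: bounds-checked loads on the short dimensions and a safe
    // store of the final result.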
    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);
    short lbk = params->K - params->gemm_k_iterations_aligned * BK;

    bool M_aligned = (tgp_bm == BM);
    bool N_aligned = (tgp_bn == BN);

    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
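    // Main K loop over full BK tiles; skip any block whose operand masks are
    // not both set.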
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
           rhs_mask
               [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        if (M_aligned) {
          loader_a.load_unsafe();
        } else {
          loader_a.load_safe(tile_dims_A);
        }
        if (N_aligned) {
          loader_b.load_unsafe();
        } else {
          loader_b.load_safe(tile_dims_B);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(As, Bs);
      }

      // Advance to the next K tile.
      loader_a.next();
      loader_b.next();
    }

    // Loop tail: leftover K with bounds-checked loads.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (!has_operand_mask ||
        (lhs_mask
             [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
         rhs_mask
             [(params->K / BM) * mask_strides[5] +
              tid_x * mask_strides[4]])) {
      short2 tile_dims_A_last =
          transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
      short2 tile_dims_B_last =
          transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

      loader_a.load_safe(tile_dims_A_last);
      loader_b.load_safe(tile_dims_B_last);

      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(As, Bs);
    }

    if (M_aligned && N_aligned) {
      mma_op.store_result(D, params->ldd);
    } else {
      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}