steel_gemm_masked.h
// Copyright © 2024 Apple Inc.

using namespace metal;
using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

struct _NoMask {
  char x;

  constexpr METAL_FUNC operator bool() {
    return true;
  }
  constexpr METAL_FUNC operator bool() const threadgroup {
    return true;
  }
  constexpr METAL_FUNC operator bool() const device {
    return true;
  }
  constexpr METAL_FUNC operator bool() const constant {
    return true;
  }
};

template <typename OutT, typename InT = OutT>
struct ScaleOp {
  OutT scale;

  METAL_FUNC OutT apply(InT x) const {
    return static_cast<OutT>(x) * scale;
  }
};

typedef struct _NoMask nomask_t;

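// Passing `nomask_t` for a mask type disables that mask at compile time: it
// always converts to true, so the corresponding mask loads and checks are
// elided. Any other mask type is read once per block; when it is not `bool`,
// its value is also applied as a multiplicative scale through `ScaleOp`.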
template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const device out_mask_t* out_mask [[buffer(10)]],
    const device op_mask_t* lhs_mask [[buffer(11)]],
    const device op_mask_t* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  // Appease the compiler
  (void)lid;

  static_assert(
      BM == BN,
      "block_masked_gemm must have the same block M and block N size");
  static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0");

  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

  constexpr bool has_mul_operand_mask =
      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
  constexpr bool has_mul_output_mask =
      has_output_mask && !metal::is_same_v<out_mask_t, bool>;

  constexpr short k_mask_factor = short(BM / BK);
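  // The mask offsets below advance by one element per BM rows/columns, i.e.
  // masks are indexed at BM x BM block granularity (hence the BM == BN and
  // BM % BK == 0 asserts above). One mask element therefore covers
  // k_mask_factor = BM / BK consecutive BK-sized K iterations; e.g. BM = 64
  // with BK = 16 gives k_mask_factor = 4.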

  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      MN_aligned,
      K_aligned>;

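  // Swizzle the tile indices so that consecutive threadgroups along x map to
  // a small block of neighboring output tiles, presumably to improve reuse of
  // the A and B blocks read by concurrently running threadgroups.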
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

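  // The mask batch strides are packed after the A and B batch strides in
  // batch_strides; offset the mask pointers to the current batch element.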
  const constant size_t* mask_batch_strides =
      batch_strides + 2 * params->batch_ndim;

  if (params->batch_ndim > 1) {
    if (has_output_mask) {
      out_mask += elem_to_loc(
          tid.z, batch_shape, mask_batch_strides, params->batch_ndim);

      mask_batch_strides += params->batch_ndim;
    }

    if (has_operand_mask) {
      const constant size_t* mask_strides_lhs = mask_batch_strides;
      const constant size_t* mask_strides_rhs =
          mask_strides_lhs + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z,
          batch_shape,
          mask_strides_lhs,
          mask_strides_rhs,
          params->batch_ndim);

      lhs_mask += batch_offsets.x;
      rhs_mask += batch_offsets.y;
    }
  } else {
    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += params->batch_ndim;
    }

    if (has_operand_mask) {
      lhs_mask += tid.z * mask_batch_strides[0];
      rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];
    }
  }

  // Adjust for batch
  if (params->batch_ndim > 1) {
    const constant size_t* A_bstrides = batch_strides;
    const constant size_t* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;

  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;
  }

  D += params->batch_stride_d * tid.z;

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

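  // mask_strides packs two strides per provided mask, ordered output mask,
  // then lhs, then rhs; index 0 strides along the mask's columns and index 1
  // along its rows.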
  const constant int* out_mask_strides = mask_strides;
  const constant int* lhs_mask_strides =
      mask_strides + (has_output_mask ? 2 : 0);
  const constant int* rhs_mask_strides =
      lhs_mask_strides + (has_operand_mask ? 2 : 0);

  const int out_mask_offset = !has_output_mask
      ? 0
      : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];
  int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];
  int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];
  const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];
  const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1];
  short k_factor_cnt = k_mask_factor;

  ScaleOp<float> out_mask_op;
  ScaleOp<T> lhs_mask_op;
  ScaleOp<T> rhs_mask_op;

  if (has_output_mask) {
    auto mask_out = out_mask[out_mask_offset];

    if (has_mul_output_mask) {
      out_mask_op.scale = float(mask_out);
    }

    // Write zeros and return
    if (!mask_out) {
      constexpr short tgp_size = WM * WN * 32;
      constexpr short vec_size = 4;

      // Tile threads in threadgroup
      constexpr short TN = BN / vec_size;
      constexpr short TM = tgp_size / TN;

      const short thread_idx = simd_group_id * 32 + simd_lane_id;
      const short bi = thread_idx / TN;
      const short bj = vec_size * (thread_idx % TN);

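      // Each thread zero-fills vec_size contiguous columns starting at row
      // bi, column bj, stepping down TM rows at a time.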
      D += bi * params->ldd + bj;

      short tgp_bm = min(BM, params->M - c_row);
      short tgp_bn = min(BN, params->N - c_col);

      if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
        for (short ti = 0; ti < BM; ti += TM) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            D[ti * params->ldd + j] = T(0.);
          }
        }
      } else {
        short jmax = tgp_bn - bj;
        jmax = jmax < vec_size ? jmax : vec_size;
        for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
          for (short j = 0; j < jmax; j++) {
            D[ti * params->ldd + j] = T(0.);
          }
        }
      }

      return;
    }
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Prepare threadgroup mma operation
  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Prepare threadgroup loading operations
  thread typename gemm_kernel::loader_a_t loader_a(
      A, params->lda, As, simd_group_id, simd_lane_id);
  thread typename gemm_kernel::loader_b_t loader_b(
      B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup bounds
  const short tgp_bm =
      MN_aligned ? short(BM) : short(min(BM, params->M - c_row));
  const short tgp_bn =
      MN_aligned ? short(BN) : short(min(BN, params->N - c_col));

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  ///////////////////////////////////////////////////////////////////////////
  // Do unaligned K iterations first
  if (!K_aligned) {
    const int k_last = params->gemm_k_iterations_aligned * BK;
    const int mask_idx_last = k_last / BM;

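    // The K tail falls in the operand-mask block at index k_last / BM along K.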
    if (!has_operand_mask ||
        (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&
         bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {
      if (has_mul_operand_mask) {
        lhs_mask_op.scale =
            lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];
        rhs_mask_op.scale =
            rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];
      }

      // Move loader source ahead to end
      const int k_remain = params->K - k_last;
      const size_t k_jump_a =
          transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
      const size_t k_jump_b =
          transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

      loader_a.src += k_jump_a;
      loader_b.src += k_jump_b;

      // Load tile
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

      loader_a.load_safe(tile_dims_A);
      loader_b.load_safe(tile_dims_B);

      if (has_mul_operand_mask) {
        loader_a.apply_inplace_op(lhs_mask_op);
        loader_b.apply_inplace_op(rhs_mask_op);
      }

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Do matmul
      mma_op.mma(As, Bs);

      // Reset source back to start
      loader_a.src -= k_jump_a;
      loader_b.src -= k_jump_b;
    }
  }

  ///////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (MN_aligned) {
    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (bool(lhs_mask[lhs_mask_offset]) &&
           bool(rhs_mask[rhs_mask_offset]))) {
        if (has_mul_operand_mask) {
          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
        }

        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        if (has_mul_operand_mask) {
          loader_a.apply_inplace_op(lhs_mask_op);
          loader_b.apply_inplace_op(rhs_mask_op);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();

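      // The operand masks advance along K only every k_mask_factor
      // iterations, since one mask block spans BM (= k_mask_factor * BK)
      // elements of K.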
      k_factor_cnt--;
      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
    }

    if (has_mul_output_mask) {
      mma_op.apply_epilogue(out_mask_op);
    }

    // Store results to device memory
    mma_op.store_result(D, params->ldd);
    return;

  }
  ///////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else {
    const bool M_aligned = (tgp_bm == BM);
    const bool N_aligned = (tgp_bn == BN);

    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (!has_operand_mask ||
          (bool(lhs_mask[lhs_mask_offset]) &&
           bool(rhs_mask[rhs_mask_offset]))) {
        if (has_mul_operand_mask) {
          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
        }

        // Load elements into threadgroup
        if (M_aligned) {
          loader_a.load_unsafe();
        } else {
          loader_a.load_safe(tile_dims_A);
        }

        if (N_aligned) {
          loader_b.load_unsafe();
        } else {
          loader_b.load_safe(tile_dims_B);
        }

        if (has_mul_operand_mask) {
          loader_a.apply_inplace_op(lhs_mask_op);
          loader_b.apply_inplace_op(rhs_mask_op);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();

      k_factor_cnt--;
      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
    }

    if (has_mul_output_mask) {
      mma_op.apply_epilogue(out_mask_op);
    }

    if (M_aligned && N_aligned) {
      mma_op.store_result(D, params->ldd);
    } else {
      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}

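// Variant taking boolean masks only: the output mask is always read, and
// operand masking is controlled by the has_operand_mask template flag rather
// than by the mask types.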
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    bool has_operand_mask = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const device bool* out_mask [[buffer(10)]],
    const device bool* lhs_mask [[buffer(11)]],
    const device bool* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  // Appease the compiler
  (void)lid;

  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      MN_aligned,
      K_aligned>;

  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  if (params->batch_ndim > 1) {
    const constant size_t* mask_batch_strides =
        batch_strides + 2 * params->batch_ndim;
    out_mask +=
        elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);

    if (has_operand_mask) {
      const constant size_t* mask_strides_lhs =
          mask_batch_strides + params->batch_ndim;
      const constant size_t* mask_strides_rhs =
          mask_strides_lhs + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z,
          batch_shape,
          mask_strides_lhs,
          mask_strides_rhs,
          params->batch_ndim);

      lhs_mask += batch_offsets.x;
      rhs_mask += batch_offsets.y;
    }
  } else {
    out_mask += tid.z * batch_strides[2 * params->batch_ndim];
    if (has_operand_mask) {
      lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
      rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
    }
  }

  // Adjust for batch
  if (params->batch_ndim > 1) {
    const constant size_t* A_bstrides = batch_strides;
    const constant size_t* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;

  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;
  }

  D += params->batch_stride_d * tid.z;

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

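  // One bool per BM x BN output tile, indexed by the (column, row) tile
  // strides in mask_strides[0] and mask_strides[1].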
  bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];

  // Write zeros and return
  if (!mask_out) {
    constexpr short tgp_size = WM * WN * 32;
    constexpr short vec_size = 4;

    // Tile threads in threadgroup
    constexpr short TN = BN / vec_size;
    constexpr short TM = tgp_size / TN;

    const short thread_idx = simd_group_id * 32 + simd_lane_id;
    const short bi = thread_idx / TN;
    const short bj = vec_size * (thread_idx % TN);

    D += bi * params->ldd + bj;

    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);

    if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
      for (short ti = 0; ti < BM; ti += TM) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          D[ti * params->ldd + j] = T(0.);
        }
      }
    } else {
      short jmax = tgp_bn - bj;
      jmax = jmax < vec_size ? jmax : vec_size;
      for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
        for (short j = 0; j < jmax; j++) {
          D[ti * params->ldd + j] = T(0.);
        }
      }
    }

    return;
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Prepare threadgroup mma operation
  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Prepare threadgroup loading operations
  thread typename gemm_kernel::loader_a_t loader_a(
      A, params->lda, As, simd_group_id, simd_lane_id);
  thread typename gemm_kernel::loader_b_t loader_b(
      B, params->ldb, Bs, simd_group_id, simd_lane_id);

  ///////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (MN_aligned) {
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
           rhs_mask
               [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Loop tail
    if (!K_aligned) {
      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
           rhs_mask
               [(params->K / BM) * mask_strides[5] +
                tid_x * mask_strides[4]])) {
        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }
    }

    // Store results to device memory
    mma_op.store_result(D, params->ldd);
    return;

  }
  ///////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else { // Loop over K - unaligned case
    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);
    short lbk = params->K - params->gemm_k_iterations_aligned * BK;

    bool M_aligned = (tgp_bm == BM);
    bool N_aligned = (tgp_bn == BN);

    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
           rhs_mask
               [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        // Load elements into threadgroup
        if (M_aligned) {
          loader_a.load_unsafe();
        } else {
          loader_a.load_safe(tile_dims_A);
        }

        if (N_aligned) {
          loader_b.load_unsafe();
        } else {
          loader_b.load_safe(tile_dims_B);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    if (!K_aligned) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
           rhs_mask
               [(params->K / BM) * mask_strides[5] +
                tid_x * mask_strides[4]])) {
        short2 tile_dims_A_last =
            transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
        short2 tile_dims_B_last =
            transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

        loader_a.load_safe(tile_dims_A_last);
        loader_b.load_safe(tile_dims_B_last);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }
    }

    if (M_aligned && N_aligned) {
      mma_op.store_result(D, params->ldd);
    } else {
      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}