MLX
Loading...
Searching...
No Matches
mma.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
5#include <metal_simdgroup>
6#include <metal_simdgroup_matrix>
7#include <metal_stdlib>
8
12
13using namespace metal;
14
16// MMA helper
18
19namespace mlx {
20namespace steel {
21
/// A two-dimensional extent: `r` rows by `c` columns. The row and column
/// types are independent so either may be a runtime integer or a
/// compile-time integral constant.
template <typename RInt, typename CInt>
struct Shape2D {
  RInt r; // row extent
  CInt c; // column extent

  Shape2D(RInt rows, CInt cols) : r{rows}, c{cols} {}
};
29
/// Pairs a shape descriptor with a layout descriptor. This is a plain
/// aggregate: brace-initialize as `Layout2D<S, L>{shape, layout}`.
template <typename Shape, typename Layout>
struct Layout2D {
  Shape shape;   // extent descriptor
  Layout layout; // layout descriptor
};
35
/// Primary MMA fragment template. Only an 8 x 8 specialization is provided;
/// instantiating any other fragment shape fails at compile time.
template <typename T, int kFragRows_, int kFragCols_>
struct BaseMMAFrag {
  static_assert(
      kFragRows_ == 8 && kFragCols_ == 8,
      "Only 8 x 8 fragment matrices are currently supported");
};
45
/// 8 x 8 simdgroup MMA fragment.
///
/// The fragment's 64 elements are split evenly over the 32 lanes of a
/// simdgroup. Each lane owns kElemsPerFrag = 2 elements, laid out as
/// kElemRows x kElemCols = 1 row x 2 adjacent columns. get_coord() gives
/// where a lane's elements sit in the fragment. frag_type is the per-lane
/// vector of those elements, and mat_type is the metal::simdgroup_matrix
/// used for the hardware multiply-accumulate.
template <typename T>
struct BaseMMAFrag<T, 8, 8> {
  STEEL_CONST int kFragRows = 8;
  STEEL_CONST int kFragCols = 8;

  // Elements held by each of the 32 lanes.
  STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;

  // Per-lane elements form a 1 x 2 strip of one fragment row.
  STEEL_CONST int kElemRows = 1;
  STEEL_CONST int kElemCols = 2;

  static_assert(
      kElemRows * kElemCols == kElemsPerFrag,
      "MMAFrag shape is not consistent with MMAFrag size");

  typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
  typedef metal::vec<T, kElemsPerFrag> frag_type;
  typedef metal::vec<T, kElemRows> row_frag_type;
  typedef metal::vec<T, kElemCols> col_frag_type;

  /// Position of this lane's elements inside the fragment, returned as
  /// short2{column, row}. The lane's kElemCols elements occupy consecutive
  /// columns starting at .x, all in row .y.
  METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
                                               [[thread_index_in_simdgroup]]) {
    const short qid = simd_lane_id / 4;
    // Row: lane bit 4 selects the upper/lower half, bits 1-2 the row within it.
    const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
    // Column: lane bit 3 selects the left/right half, bit 0 the column pair.
    const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
    return short2{fn, fm};
  }

  /// Loads this lane's elements. Element (i, j) is read from
  /// src[i * str_x + j * str_y] and converted to T.
  template <typename SrcPtrType, typename StrX, typename StrY>
  METAL_FUNC static constexpr void
  load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
      }
    }
  }

  /// Bounds-checked load.
  ///
  /// Element (i, j) is taken from row (off_x + i) and column (off_y + j) of
  /// src, i.e. src[(off_x + i) * str_x + (off_y + j) * str_y]. It is read
  /// only when row < lim_x and column < lim_y; otherwise it is set to zero.
  /// off_x / off_y default to compile-time zero. The default template
  /// arguments make those defaults usable, since default function arguments
  /// never participate in template argument deduction.
  template <
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void load_safe(
      thread frag_type& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          // The column index uses the column offset (off_y), matching the
          // bounds check above and store_safe below.
          dst[i * kElemCols + j] =
              static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
        } else {
          dst[i * kElemCols + j] = T(0);
        }
      }
    }
  }

  /// Stores this lane's elements. Element (i, j) is written to
  /// dst[i * str_x + j * str_y], converted to the destination element type.
  template <typename DstPtrType, typename StrX, typename StrY>
  METAL_FUNC static constexpr void
  store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
      }
    }
  }

  /// Bounds-checked store.
  ///
  /// Element (i, j) is written to row (off_x + i), column (off_y + j) of dst
  /// only when row < lim_x and column < lim_y. Out-of-range elements are
  /// skipped. off_x / off_y default to compile-time zero.
  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void store_safe(
      const thread frag_type& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  /// D = A * B + C on per-lane fragment vectors.
  ///
  /// Copies each lane's elements into simdgroup matrices through
  /// thread_elements(), runs the matrix MMA, then copies D back out.
  METAL_FUNC static constexpr void mma(
      thread frag_type& D,
      thread frag_type& A,
      thread frag_type& B,
      thread frag_type& C) {
    mat_type D_mat;
    mat_type A_mat;
    mat_type B_mat;
    mat_type C_mat;

    reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
    reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
    reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;

    mma(D_mat, A_mat, B_mat, C_mat);

    D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
  }

  /// D = A * B + C on simdgroup matrices.
  METAL_FUNC static constexpr void mma(
      thread mat_type& D,
      thread mat_type& A,
      thread mat_type& B,
      thread mat_type& C) {
    simdgroup_multiply_accumulate(D, A, B, C);
  }

  /// Reduces one fragment row with Op and folds the result into
  /// reduced_vals[0].
  ///
  /// First combines this lane's two elements. Then it combines across the
  /// lanes that share the same row: per get_coord, lane-id bits 0 and 3 only
  /// change the column, so the partners are lane ^ 1 and lane ^ 8.
  /// reduced_vals[0] is combined with the result, not overwritten, so it
  /// must already hold a valid starting value.
  template <typename Op>
  METAL_FUNC static constexpr void row_reduce(
      thread const frag_type& inp_vals,
      thread T* reduced_vals) {
    T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);

    T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
    qgr_reduce = Op::apply(thr_reduce, qgr_reduce);

    T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
    sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);

    reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
  }

  /// Replaces every element e of row i with Op::apply(e, row_vals[i]).
  template <typename Op>
  METAL_FUNC static constexpr void row_bin_op(
      thread frag_type& inp_vals,
      thread T* row_vals) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        inp_vals[i * kElemCols + j] =
            Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
      }
    }
  }
};
217
// Tile of kTileRows x kTileCols MMA fragments owned by the calling thread.
// Fragment (i, j) is stored at val_frags[i * kTileCols + j] (row-major over
// fragments). The per-fragment element layout, loads, stores and reductions
// are all delegated to MMAFrag_t.
// NOTE(review): this listing omits the declarations of kRows, kCols,
// kNumFrags, kElemsPerTile (original lines 233-237) and val_frags (original
// line 245), as well as the unroll pragmas.
218template <
219 typename T,
220 int kTileRows_,
221 int kTileCols_,
222 class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
223struct MMATile {
224 using MMAFrag_t = MMAFrag_;
225 using elem_type = T;
226 STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
227 STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
228 STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
229
230 STEEL_CONST int kTileRows = kTileRows_;
231 STEEL_CONST int kTileCols = kTileCols_;
232
235
238
// Per-thread row / column counts: one MMAFrag_t row (column) share for each
// fragment row (column) of the tile. row_reduce / row_bin_op index their
// `vals` arrays with this row count.
239 STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
240 STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;
241
242 typedef typename MMAFrag_t::mat_type mat_type;
243 typedef typename MMAFrag_t::frag_type frag_type;
244
246
247 METAL_FUNC MMATile() thread {}
248
// Sets every fragment to zero.
249 METAL_FUNC constexpr void clear() {
251 for (short i = 0; i < kNumFrags; ++i) {
252 val_frags[i] = frag_type(0);
253 }
254 }
255
// Fragment (i, j) of the tile (row-major storage).
256 METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
257 return val_frags[i * kTileCols + j];
258 }
259
260 METAL_FUNC constexpr const thread frag_type& frag_at(
261 const short i,
262 const short j) const {
263 return val_frags[i * kTileCols + j];
264 }
265
// Copies fragment (i, j) element by element into a simdgroup matrix value.
266 METAL_FUNC mat_type mat_at(const short i, const short j) {
267 mat_type val_mat;
269 for (short ii = 0; ii < kElemsPerFrag; ++ii) {
270 val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
271 }
272 return val_mat;
273 }
274
// Flat view of all fragment elements (val_frags reinterpreted as elem_type*).
275 METAL_FUNC thread elem_type* elems() {
276 return reinterpret_cast<thread elem_type*>(val_frags);
277 }
278
279 METAL_FUNC const thread elem_type* elems() const {
280 return reinterpret_cast<const thread elem_type*>(val_frags);
281 }
282
// Row-reduces every fragment with Op into vals. Fragment row i uses
// vals[i * MMAFrag_t::kElemRows ...], and all kTileCols fragments in that row
// feed the same entries. With BaseMMAFrag the existing entries are combined
// into rather than overwritten, so callers must pre-initialize vals.
283 template <typename Op>
284 METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
286 for (short i = 0; i < kTileRows; ++i) {
288 for (short j = 0; j < kTileCols; ++j) {
289 MMAFrag_t::template row_reduce<Op>(
290 frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
291 }
292 }
293 }
294
// Applies Op(elem, vals[row]) to every element, using the per-row value for
// the fragment row the element belongs to.
295 template <typename Op>
296 METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
298 for (short i = 0; i < kTileRows; ++i) {
300 for (short j = 0; j < kTileCols; ++j) {
301 MMAFrag_t::template row_bin_op<Op>(
302 frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
303 }
304 }
305 }
306
// Loads every fragment from threadgroup memory. Fragment (i, j) starts at
// element offset (i * kFragRows * w_x) * str_x + (j * kFragCols * w_y) * str_y.
// w_x / w_y space fragments apart (BlockMMA passes its WM / WN here), and
// str_x / str_y are compile-time row / column strides.
307 template <typename U, int w_x, int w_y, int str_x, int str_y>
308 METAL_FUNC void load(const threadgroup U* src) {
310 for (short i = 0; i < kTileRows; ++i) {
312 for (short j = 0; j < kTileCols; ++j) {
313 MMAFrag_t::load(
314 frag_at(i, j),
315 &(
316 src[(i * kFragRows) * w_x * str_x +
317 (j * kFragCols) * w_y * str_y]),
318 Int<str_x>{},
319 Int<str_y>{});
320 }
321 }
322 }
323
// Threadgroup-memory counterpart of load() above, with the same addressing.
324 template <typename U, int w_x, int w_y, int str_x, int str_y>
325 METAL_FUNC void store(threadgroup U* dst) const {
327 for (short i = 0; i < kTileRows; ++i) {
329 for (short j = 0; j < kTileCols; ++j) {
330 MMAFrag_t::store(
331 frag_at(i, j),
332 &(
333 dst[(i * kFragRows) * w_x * str_x +
334 (j * kFragCols) * w_y * str_y]),
335 Int<str_x>{},
336 Int<str_y>{});
337 }
338 }
339 }
340
// Loads every fragment from device memory with row stride `ld` and unit
// column stride. Fragment spacing is the same as in the threadgroup load.
341 template <typename U, int w_x, int w_y>
342 METAL_FUNC void load(const device U* src, const int ld) {
344 for (short i = 0; i < kTileRows; ++i) {
346 for (short j = 0; j < kTileCols; ++j) {
347 MMAFrag_t::load(
348 frag_at(i, j),
349 &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
350 ld,
351 Int<1>{});
352 }
353 }
354 }
355
// Device-memory counterpart of the device load above.
356 template <typename U, int w_x, int w_y>
357 METAL_FUNC void store(device U* dst, const int ld) const {
359 for (short i = 0; i < kTileRows; ++i) {
361 for (short j = 0; j < kTileCols; ++j) {
362 MMAFrag_t::store(
363 frag_at(i, j),
364 &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
365 ld,
366 Int<1>{});
367 }
368 }
369 }
370
// Bounds-checked device load. src_tile_dims is (columns, rows): .y limits
// rows and .x limits columns. Fragment offsets are passed to
// MMAFrag_t::load_safe instead of being added to src, so this relies on that
// routine applying off_x to rows and off_y to columns.
371 template <typename U, int w_x, int w_y>
372 METAL_FUNC void
373 load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
375 for (int i = 0; i < kTileRows; ++i) {
377 for (int j = 0; j < kTileCols; ++j) {
378 MMAFrag_t::load_safe(
379 frag_at(i, j),
380 src,
381 ld,
382 Int<1>{},
383 src_tile_dims.y,
384 src_tile_dims.x,
385 (i * kFragRows) * w_x,
386 (j * kFragCols) * w_y);
387 }
388 }
389 }
390
// Bounds-checked device store. dst_tile_dims is (columns, rows), as in
// load_safe; out-of-range elements are left to MMAFrag_t::store_safe to skip.
391 template <typename U, int w_x, int w_y>
392 METAL_FUNC void
393 store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
395 for (int i = 0; i < kTileRows; ++i) {
397 for (int j = 0; j < kTileCols; ++j) {
398 MMAFrag_t::store_safe(
399 frag_at(i, j),
400 dst,
401 ld,
402 Int<1>{},
403 dst_tile_dims.y,
404 dst_tile_dims.x,
405 (i * kFragRows) * w_x,
406 (j * kFragCols) * w_y);
407 }
408 }
409 }
410};
411
/// Fragment-wise tile multiply-accumulate.
///
/// For each k in order, and each fragment (m, n):
///   D(m, n) = A(m, k) * B(k, n) + C(m, n)
/// computed by the fragment type's mma(). Each k step reads C and overwrites
/// D, so partial products accumulate across k only when D and C are the
/// same tile (or K == 1).
///
/// Columns are visited in serpentine order: odd m walks n from N - 1 down
/// to 0. As a result, the B fragment used last in one row is the first one
/// used in the next.
template <typename T, typename U, int M, int N, int K>
METAL_FUNC void tile_matmad(
    thread MMATile<T, M, N>& D,
    thread MMATile<U, M, K>& A,
    thread MMATile<U, K, N>& B,
    thread MMATile<T, M, N>& C) {
  STEEL_PRAGMA_UNROLL
  for (short k = 0; k < K; ++k) {
    STEEL_PRAGMA_UNROLL
    for (short m = 0; m < M; ++m) {
      STEEL_PRAGMA_UNROLL
      for (short n = 0; n < N; ++n) {
        const short n_serp = (m % 2) ? (N - 1 - n) : n;
        MMATile<T, M, N>::MMAFrag_t::mma(
            D.frag_at(m, n_serp),
            A.frag_at(m, k),
            B.frag_at(k, n_serp),
            C.frag_at(m, n_serp));
      }
    }
  }
}
434
/// Threadgroup-level block multiply-accumulate: (BM x BK) * (BK x BN).
///
/// Simdgroups form a WM x WN grid: simd_group_id / WN selects the grid row
/// and simd_group_id % WN the grid column. Each simdgroup accumulates a
/// TM x TN grid of 8 x 8 fragments in Ctile. Its fragments are interleaved
/// with those of the other simdgroups at strides TM_stride (rows) and
/// TN_stride (columns). A and B are read from threadgroup memory with
/// leading dimensions lda_tgp / ldb_tgp, optionally transposed. Results are
/// written to device memory, optionally through an epilogue.
template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    short lda_tgp,
    short ldb_tgp,
    typename AccumType = float,
    typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
  // MMAFrag size
  STEEL_CONST short kFragSize = 8;
  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;

  // Warp tile simdgroup matrix strides along M
  STEEL_CONST short TM_stride = kFragSize * WM;
  // Warp tile simdgroup matrix strides along N
  STEEL_CONST short TN_stride = kFragSize * WN;

  // Warp tile size along M
  STEEL_CONST short TM = BM / TM_stride;
  // Warp tile size along N
  STEEL_CONST short TN = BN / TN_stride;

  // Threadgroup A strides
  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K

  // Threadgroup B strides
  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
  STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N

  // Threadgroup strides along K (one fragment's worth of K per step)
  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;

  // Simdgroup matrices
  MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
  MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;

  // Offsets within threadgroup. After construction these are the row (sm)
  // and column (sn) of this thread's first accumulator element.
  short sm;
  short sn;

  // Element offsets of this thread's first A / B value within As / Bs
  short As_offset;
  short Bs_offset;

  /* Constructor */
  METAL_FUNC BlockMMA(
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
    // Determine thread position in simdgroup matrix
    short tm = kFragSize * (simd_group_id / WN);
    short tn = kFragSize * (simd_group_id % WN);

    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
    sm = simd_coord.y;
    sn = simd_coord.x;

    // Determine thread and simdgroup offset
    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N

    sm += tm;
    sn += tn;
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Adjust for simdgroup and thread location
    As += As_offset;
    Bs += Bs_offset;

    // Iterate over BK in blocks of kFragSize
    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < BK; kk += kFragSize) {
      simdgroup_barrier(mem_flags::mem_none);

      Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);

      simdgroup_barrier(mem_flags::mem_none);

      Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);

      simdgroup_barrier(mem_flags::mem_none);

      // Ctile is both source and destination, so products accumulate over kk.
      tile_matmad(Ctile, Atile, Btile, Ctile);

      // Progress to next simdgroup tile
      As += tile_stride_a;
      Bs += tile_stride_b;
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(device U* D, const int ldd) {
    // Apply epilogue
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
    }

    // Adjust for simdgroup and thread location
    D += sm * ldd + sn;

    Ctile.template store<U, WM, WN>(D, ldd);
  }

  /// Bounds-checked store. dst_tile_dims is (columns, rows) of the valid
  /// region of D, measured from the block origin.
  METAL_FUNC void
  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
    // Apply epilogue
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
    }

    // Adjust for simdgroup and thread location
    D += sm * ldd + sn;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
  }

  /* Apply epilogue */
  template <typename UnaryEpilogue>
  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
    }
  }

  /* Apply epilogue */
  /// Combines each accumulator with the matching element of C, where ldc is
  /// C's row stride and fdc its column stride.
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue(
      const device U* C,
      const int ldc,
      const int fdc,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
          accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
        }
      }
    }
  }

  /* Apply epilogue */
  /// Bounds-checked version of the binary apply_epilogue. dst_tile_dims is
  /// (columns, rows) of the valid region. Elements of C outside it are never
  /// read; the epilogue is applied with a zero C value for those positions.
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue_safe(
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      // This thread's row in fragment row i; rows past the edge must not read
      // C (same check as store_result_safe).
      const bool row_in_bounds = (i * TM_stride) < dst_tile_dims.y;

      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        // Read C (zero outside the valid region)
        U c_elems[kelems] = {0};

        if (row_in_bounds) {
          STEEL_PRAGMA_UNROLL
          for (short k = 0; k < kelems; k++) {
            if ((j * TN_stride + k) < dst_tile_dims.x) {
              c_elems[k] = C[offset_c + k * fdc];
            }
          }
        }

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
        }
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  /// Writes epilogue_op.apply(accum, C) to D without modifying Ctile.
  METAL_FUNC void store_result(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;

    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread const auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
        int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
        }
      }
    }
  }

  /// Bounds-checked version of the C-combining store_result; dst_tile_dims
  /// is (columns, rows) of the valid region of D and C.
  METAL_FUNC void store_result_safe(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = Ctile.frag_at(i, j);
          int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
          int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

          // Apply epilogue
          STEEL_PRAGMA_UNROLL
          for (short k = 0; k < kelems; k++) {
            if ((j * TN_stride + k) < dst_tile_dims.x) {
              D[offset_d + k] =
                  epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
            }
          }
        }
      }
    }
  }
};
724
725} // namespace steel
726} // namespace mlx
Definition bf16_math.h:226
METAL_FUNC bfloat16_t simd_shuffle_xor(bfloat16_t data, ushort mask)
Definition bf16_math.h:377
typename pointer_element< remove_cv_t< T > >::type pointer_element_t
Definition type_traits.h:51
METAL_FUNC void tile_matmad(thread MMATile< T, M, N > &D, thread MMATile< U, M, K > &A, thread MMATile< U, K, N > &B, thread MMATile< T, M, N > &C)
Definition mma.h:413
Definition allocator.h:7
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
static METAL_FUNC constexpr void mma(thread mat_type &D, thread mat_type &A, thread mat_type &B, thread mat_type &C)
Definition mma.h:180
static METAL_FUNC constexpr void store_safe(const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{})
Definition mma.h:138
static METAL_FUNC constexpr void row_bin_op(thread frag_type &inp_vals, thread T *row_vals)
Definition mma.h:204
static METAL_FUNC constexpr void row_reduce(thread const frag_type &inp_vals, thread T *reduced_vals)
Definition mma.h:189
metal::vec< T, kElemRows > row_frag_type
Definition mma.h:62
static METAL_FUNC constexpr short2 get_coord(ushort simd_lane_id)
Definition mma.h:65
static METAL_FUNC constexpr void mma(thread frag_type &D, thread frag_type &A, thread frag_type &B, thread frag_type &C)
Definition mma.h:161
metal::simdgroup_matrix< T, kFragRows, kFragCols > mat_type
Definition mma.h:60
metal::vec< T, kElemsPerFrag > frag_type
Definition mma.h:61
static METAL_FUNC constexpr void store(const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y)
Definition mma.h:118
metal::vec< T, kElemCols > col_frag_type
Definition mma.h:63
static METAL_FUNC constexpr void load(thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y)
Definition mma.h:75
static METAL_FUNC constexpr void load_safe(thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{})
Definition mma.h:93
Definition mma.h:23
METAL_FUNC void store_result(device U *D, const int ldd)
Definition mma.h:536
METAL_FUNC void store_result_safe(device U *D, const int ldd, short2 dst_tile_dims)
Definition mma.h:550
short As_offset
Definition mma.h:485
MMATile< AccumType, TM, TN, MMAFrag_acc_t > Ctile
Definition mma.h:479
STEEL_CONST short A_str_k
Definition mma.h:466
MMATile< AccumType, 1, TN, MMAFrag_acc_t > Btile
Definition mma.h:478
MMATile< AccumType, TM, 1, MMAFrag_acc_t > Atile
Definition mma.h:477
STEEL_CONST short B_str_n
Definition mma.h:470
STEEL_CONST short TM_stride
Definition mma.h:455
METAL_FUNC void mma(const threadgroup T *As, const threadgroup T *Bs)
Definition mma.h:509
STEEL_CONST short TN
Definition mma.h:462
METAL_FUNC void store_result_safe(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const Epilogue &epilogue_op) const
Definition mma.h:683
METAL_FUNC void store_result(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, thread const Epilogue &epilogue_op) const
Definition mma.h:651
METAL_FUNC void apply_epilogue(const device U *C, const int ldc, const int fdc, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:579
STEEL_CONST short TN_stride
Definition mma.h:457
STEEL_CONST short tile_stride_a
Definition mma.h:473
short Bs_offset
Definition mma.h:486
METAL_FUNC void apply_epilogue_safe(const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:607
METAL_FUNC BlockMMA(ushort simd_group_id, ushort simd_lane_id)
Definition mma.h:489
STEEL_CONST short B_str_k
Definition mma.h:469
short sm
Definition mma.h:482
STEEL_CONST short A_str_m
Definition mma.h:465
STEEL_CONST short TM
Definition mma.h:460
short sn
Definition mma.h:483
STEEL_CONST short tile_stride_b
Definition mma.h:474
STEEL_CONST short kFragSize
Definition mma.h:451
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue &epilogue_op)
Definition mma.h:569
Definition mma.h:31
Shape shape
Definition mma.h:32
Layout layout
Definition mma.h:33
Definition mma.h:178
METAL_FUNC constexpr thread frag_type & frag_at(const short i, const short j)
Definition mma.h:256
STEEL_CONST int kTileRows
Definition mma.h:230
STEEL_CONST int kColsPerThread
Definition mma.h:240
MMAFrag_t::mat_type mat_type
Definition mma.h:242
METAL_FUNC void store(threadgroup U *dst) const
Definition mma.h:325
METAL_FUNC mat_type mat_at(const short i, const short j)
Definition mma.h:266
METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread])
Definition mma.h:296
STEEL_CONST int kTileCols
Definition mma.h:231
METAL_FUNC void store_safe(device U *dst, const int ld, const short2 dst_tile_dims) const
Definition mma.h:393
STEEL_CONST int kFragRows
Definition mma.h:226
STEEL_CONST int kRowsPerThread
Definition mma.h:239
STEEL_CONST int kRows
Definition mma.h:233
frag_type val_frags[kNumFrags]
Definition mma.h:245
MMAFrag_ MMAFrag_t
Definition mma.h:224
METAL_FUNC void store(device U *dst, const int ld) const
Definition mma.h:357
T elem_type
Definition mma.h:225
METAL_FUNC thread elem_type * elems()
Definition mma.h:275
STEEL_CONST int kCols
Definition mma.h:234
STEEL_CONST int kElemsPerTile
Definition mma.h:237
METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const
Definition mma.h:284
METAL_FUNC void load_safe(const device U *src, const int ld, const short2 src_tile_dims)
Definition mma.h:373
METAL_FUNC MMATile() thread
Definition mma.h:247
METAL_FUNC void load(const threadgroup U *src)
Definition mma.h:308
METAL_FUNC constexpr void clear()
Definition mma.h:249
METAL_FUNC void load(const device U *src, const int ld)
Definition mma.h:342
MMAFrag_t::frag_type frag_type
Definition mma.h:243
STEEL_CONST int kFragCols
Definition mma.h:227
METAL_FUNC constexpr const thread frag_type & frag_at(const short i, const short j) const
Definition mma.h:260
METAL_FUNC const thread elem_type * elems() const
Definition mma.h:279
STEEL_CONST int kNumFrags
Definition mma.h:236
STEEL_CONST int kElemsPerFrag
Definition mma.h:228
Definition mma.h:23
Shape2D(RInt r_, CInt c_)
Definition mma.h:27
RInt r
Definition mma.h:24
CInt c
Definition mma.h:25
Definition integral_constant.h:18