mma.h
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

using namespace metal;

// MMA helper

namespace mlx {
namespace steel {

template <typename RInt, typename CInt>
struct Shape2D {
  RInt r;
  CInt c;

  Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {}
};

template <typename Shape, typename Layout>
struct Layout2D {
  Shape shape;
  Layout layout;
};

template <typename T, int kFragRows_, int kFragCols_>
struct BaseMMAFrag {
  static_assert(
      kFragRows_ == 8,
      "Only 8 x 8 fragment matrices are currently supported");
  static_assert(
      kFragCols_ == 8,
      "Only 8 x 8 fragment matrices are currently supported");
};

template <typename T>
struct BaseMMAFrag<T, 8, 8> {
  STEEL_CONST int kFragRows = 8;
  STEEL_CONST int kFragCols = 8;

  STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;

  STEEL_CONST int kElemRows = 1;
  STEEL_CONST int kElemCols = 2;

  static_assert(
      kElemRows * kElemCols == kElemsPerFrag,
      "MMAFrag shape is not consistent with MMAFrag size");

  typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
  typedef metal::vec<T, kElemsPerFrag> frag_type;
  typedef metal::vec<T, kElemRows> row_frag_type;
  typedef metal::vec<T, kElemCols> col_frag_type;

  template <typename U>
  using dtype_mat_t = typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;

  template <typename U>
  using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;

  METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
                                               [[thread_index_in_simdgroup]]) {
    const short qid = simd_lane_id / 4;
    const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
    const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
    return short2{fn, fm};
  }

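  // Each of the 32 lanes in the simdgroup owns a kElemRows x kElemCols
  // (1 x 2) strip of the 8 x 8 fragment. For example, the arithmetic above
  // maps lane 0 -> (x=0, y=0), lane 1 -> (x=2, y=0), and lane 2 ->
  // (x=0, y=1), so lane 1 holds elements (0, 2) and (0, 3) of the fragment.
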
  template <typename SrcPtrType, typename StrX, typename StrY>
  METAL_FUNC static constexpr void
  load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
      }
    }
  }

  template <
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void load_safe(
      thread frag_type& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          dst[i * kElemCols + j] =
              static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
        } else {
          dst[i * kElemCols + j] = T(0);
        }
      }
    }
  }

  template <typename DstPtrType, typename StrX, typename StrY>
  METAL_FUNC static constexpr void
  store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
      }
    }
  }

  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void store_safe(
      const thread frag_type& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  template <typename Atype, typename Btype, typename Ctype>
  METAL_FUNC static constexpr void mma(
      thread frag_type& D,
      thread dtype_frag_t<Atype>& A,
      thread dtype_frag_t<Btype>& B,
      thread dtype_frag_t<Ctype>& C) {
    mat_type D_mat;
    dtype_mat_t<Atype> A_mat;
    dtype_mat_t<Btype> B_mat;
    dtype_mat_t<Ctype> C_mat;

    reinterpret_cast<thread dtype_frag_t<Atype>&>(A_mat.thread_elements()) = A;
    reinterpret_cast<thread dtype_frag_t<Btype>&>(B_mat.thread_elements()) = B;
    reinterpret_cast<thread dtype_frag_t<Ctype>&>(C_mat.thread_elements()) = C;

    mma(D_mat, A_mat, B_mat, C_mat);

    D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
  }

  template <typename Atype, typename Btype, typename Ctype>
  METAL_FUNC static constexpr void mma(
      thread mat_type& D,
      thread dtype_mat_t<Atype>& A,
      thread dtype_mat_t<Btype>& B,
      thread dtype_mat_t<Ctype>& C) {
    simdgroup_multiply_accumulate(D, A, B, C);
  }

  template <typename Op>
  METAL_FUNC static constexpr void row_reduce(
      thread const frag_type& inp_vals,
      thread T* reduced_vals) {
    T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);

    T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
    qgr_reduce = Op::apply(thr_reduce, qgr_reduce);

    T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
    sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);

    reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
  }

  template <typename Op>
  METAL_FUNC static constexpr void row_bin_op(
      thread frag_type& inp_vals,
      thread T* row_vals) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        inp_vals[i * kElemCols + j] =
            Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
      }
    }
  }
};

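// Fragment-level usage sketch (illustrative only; `A_ptr`, `B_ptr`, `D_ptr`
// and the leading dimensions `lda`, `ldb`, `ldd` are hypothetical pointers
// and strides, assumed to be pre-offset to this lane's position via
// get_coord):
//
//   using Frag = BaseMMAFrag<float, 8, 8>;
//   Frag::frag_type a_frag, b_frag;
//   Frag::frag_type acc = Frag::frag_type(0);
//   Frag::load(a_frag, A_ptr, lda, Int<1>{});
//   Frag::load(b_frag, B_ptr, ldb, Int<1>{});
//   Frag::mma(acc, a_frag, b_frag, acc); // acc = a_frag * b_frag + acc
//   Frag::store(acc, D_ptr, ldd, Int<1>{});
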
template <
    typename T,
    int kTileRows_,
    int kTileCols_,
    class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
struct MMATile {
  using MMAFrag_t = MMAFrag_;
  using elem_type = T;
  STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
  STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
  STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;

  STEEL_CONST int kTileRows = kTileRows_;
  STEEL_CONST int kTileCols = kTileCols_;

  STEEL_CONST int kRows = kTileRows * kFragRows;
  STEEL_CONST int kCols = kTileCols * kFragCols;

  STEEL_CONST int kNumFrags = kTileRows * kTileCols;
  STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;

  STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
  STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;

  typedef typename MMAFrag_t::mat_type mat_type;
  typedef typename MMAFrag_t::frag_type frag_type;

  frag_type val_frags[kNumFrags]; // = {frag_type(0)};

  METAL_FUNC MMATile() thread {}

  METAL_FUNC constexpr void clear() {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kNumFrags; ++i) {
      val_frags[i] = frag_type(0);
    }
  }

  METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
    return val_frags[i * kTileCols + j];
  }

  METAL_FUNC constexpr const thread frag_type& frag_at(
      const short i,
      const short j) const {
    return val_frags[i * kTileCols + j];
  }

  METAL_FUNC mat_type mat_at(const short i, const short j) {
    mat_type val_mat;
    STEEL_PRAGMA_UNROLL
    for (short ii = 0; ii < kElemsPerFrag; ++ii) {
      val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
    }
    return val_mat;
  }

  METAL_FUNC thread elem_type* elems() {
    return reinterpret_cast<thread elem_type*>(val_frags);
  }

  METAL_FUNC const thread elem_type* elems() const {
    return reinterpret_cast<const thread elem_type*>(val_frags);
  }

  template <typename Op>
  METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::template row_reduce<Op>(
            frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
      }
    }
  }

  template <typename Op>
  METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::template row_bin_op<Op>(
            frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
      }
    }
  }

  template <typename U, int w_x, int w_y, int str_x, int str_y>
  METAL_FUNC void load(const threadgroup U* src) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::load(
            frag_at(i, j),
            &(
                src[(i * kFragRows) * w_x * str_x +
                    (j * kFragCols) * w_y * str_y]),
            Int<str_x>{},
            Int<str_y>{});
      }
    }
  }

  template <typename U, int w_x, int w_y, int str_x, int str_y>
  METAL_FUNC void store(threadgroup U* dst) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store(
            frag_at(i, j),
            &(
                dst[(i * kFragRows) * w_x * str_x +
                    (j * kFragCols) * w_y * str_y]),
            Int<str_x>{},
            Int<str_y>{});
      }
    }
  }

  template <typename U, int w_x, int w_y>
  METAL_FUNC void load(const device U* src, const int ld) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::load(
            frag_at(i, j),
            &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
            ld,
            Int<1>{});
      }
    }
  }

  template <typename U, int w_x, int w_y>
  METAL_FUNC void store(device U* dst, const int ld) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store(
            frag_at(i, j),
            &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
            ld,
            Int<1>{});
      }
    }
  }

  template <typename U, int w_x, int w_y>
  METAL_FUNC void
  load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        MMAFrag_t::load_safe(
            frag_at(i, j),
            src,
            ld,
            Int<1>{},
            src_tile_dims.y,
            src_tile_dims.x,
            (i * kFragRows) * w_x,
            (j * kFragCols) * w_y);
      }
    }
  }

  template <typename U, int w_x, int w_y>
  METAL_FUNC void
  store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store_safe(
            frag_at(i, j),
            dst,
            ld,
            Int<1>{},
            dst_tile_dims.y,
            dst_tile_dims.x,
            (i * kFragRows) * w_x,
            (j * kFragCols) * w_y);
      }
    }
  }
};

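// Register-tile usage sketch (illustrative only; the tile shape, `C_ptr`, and
// `ldc` are assumptions for the example): declare a 2 x 2 tile of 8 x 8
// fragments (16 x 16 values), clear it, accumulate into it, then store it to
// device memory with unit simdgroup spacing.
//
//   MMATile<float, 2, 2> Ctile;
//   Ctile.clear();
//   // ... accumulate into Ctile (e.g. via tile_matmad below) ...
//   Ctile.template store<float, 1, 1>(C_ptr, ldc);
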
template <
    typename Dtype,
    typename Atype,
    typename Btype,
    typename Ctype,
    int M,
    int N,
    int K,
    class MMAFragD,
    class MMAFragA,
    class MMAFragB,
    class MMAFragC>
METAL_FUNC void tile_matmad(
    thread MMATile<Dtype, M, N, MMAFragD>& D,
    thread MMATile<Atype, M, K, MMAFragA>& A,
    thread MMATile<Btype, K, N, MMAFragB>& B,
    thread MMATile<Ctype, M, N, MMAFragC>& C) {
  STEEL_PRAGMA_UNROLL
  for (short m = 0; m < M; ++m) {
    STEEL_PRAGMA_UNROLL
    for (short n = 0; n < N; ++n) {
      short m_serp = m; // (n % 2) ? (M - 1 - m) : m;
      short n_serp = (m % 2) ? (N - 1 - n) : n;

      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < K; ++k) {
        MMAFragD::mma(
            D.frag_at(m_serp, n_serp),
            A.frag_at(m_serp, k),
            B.frag_at(k, n_serp),
            C.frag_at(m_serp, n_serp));
      }
    }
  }
}

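// tile_matmad usage sketch (illustrative only; the tile shapes are
// assumptions): computes D = A * B + C over register tiles, here with C
// aliasing D for in-place accumulation, much as BlockMMA::mma does below.
//
//   MMATile<float, 2, 1> Atile; // 2 x 1 fragments along (M, K)
//   MMATile<float, 1, 2> Btile; // 1 x 2 fragments along (K, N)
//   MMATile<float, 2, 2> Ctile; // 2 x 2 fragments along (M, N)
//   tile_matmad(Ctile, Atile, Btile, Ctile);
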
template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    short lda_tgp,
    short ldb_tgp,
    typename AccumType = float,
    typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
  // MMAFrag size
  STEEL_CONST short kFragSize = 8;
  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;

  // Warp tile simdgroup matrix strides along M
  STEEL_CONST short TM_stride = kFragSize * WM;
  // Warp tile simdgroup matrix strides along N
  STEEL_CONST short TN_stride = kFragSize * WN;

  // Warp tile size along M
  STEEL_CONST short TM = BM / TM_stride;
  // Warp tile size along N
  STEEL_CONST short TN = BN / TN_stride;

  // Threadgroup A strides
  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K

  // Threadgroup B strides
  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
  STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N

  // Threadgroup strides along K
  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;

  // Simdgroup matrices
  MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
  MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;

  // Offsets within threadgroup
  short sm;
  short sn;

  short As_offset;
  short Bs_offset;

  /* Constructor */
  METAL_FUNC BlockMMA(
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
    // Determine thread position in simdgroup matrix
    short tm = kFragSize * (simd_group_id / WN);
    short tn = kFragSize * (simd_group_id % WN);

    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
    sm = simd_coord.y;
    sn = simd_coord.x;

    // Determine thread and simdgroup offset
    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N

    sm += tm;
    sn += tn;
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Adjust for simdgroup and thread location
    As += As_offset;
    Bs += Bs_offset;

    // Iterate over BK in blocks of kFragSize
    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < BK; kk += kFragSize) {
      simdgroup_barrier(mem_flags::mem_none);

      Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);

      simdgroup_barrier(mem_flags::mem_none);

      Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);

      simdgroup_barrier(mem_flags::mem_none);

      tile_matmad(Ctile, Atile, Btile, Ctile);

      // Progress to next simdgroup tile
      As += tile_stride_a;
      Bs += tile_stride_b;
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(device U* D, const int ldd) {
    // Apply epilogue
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
    }

    // Adjust for simdgroup and thread location
    D += sm * ldd + sn;

    Ctile.template store<U, WM, WN>(D, ldd);
  }

  METAL_FUNC void
  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
    // Apply epilogue
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
    }

    // Adjust for simdgroup and thread location
    D += sm * ldd + sn;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
  }

  /* Apply epilogue */
  template <typename UnaryEpilogue>
  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
    }
  }

  /* Apply epilogue */
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue(
      const device U* C,
      const int ldc,
      const int fdc,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
          accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
        }
      }
    }
  }

  /* Apply epilogue */
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue_safe(
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

        // Read C
        U c_elems[kelems] = {0};

        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          if ((j * TN_stride + k) < dst_tile_dims.x) {
            c_elems[k] = C[offset_c + k * fdc];
          }
        }

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
        }
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;

    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread const auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
        int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
        }
      }
    }
  }

  METAL_FUNC void store_result_safe(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = Ctile.frag_at(i, j);
          int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
          int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

          // Apply epilogue
          STEEL_PRAGMA_UNROLL
          for (short k = 0; k < kelems; k++) {
            if ((j * TN_stride + k) < dst_tile_dims.x) {
              D[offset_d + k] =
                  epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
            }
          }
        }
      }
    }
  }
};

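// Kernel-level usage sketch (illustrative only; the block/warp sizes, the
// threadgroup tiles `As`/`Bs`, the K loop, the cooperative copies, and the
// device pointer `D` with leading dimension `ldd` are assumptions about the
// surrounding kernel, not part of this header):
//
//   BlockMMA<half, half, 32, 32, 16, 2, 2, false, false, 16, 32> mma_op(
//       simd_group_id, simd_lane_id);
//   for (int kb = 0; kb < K; kb += 16) {
//     // ... cooperatively copy the next A (32 x 16) and B (16 x 32) blocks
//     //     into the threadgroup tiles As and Bs ...
//     threadgroup_barrier(mem_flags::mem_threadgroup);
//     mma_op.mma(As, Bs);
//     threadgroup_barrier(mem_flags::mem_threadgroup);
//   }
//   mma_op.store_result(D, ldd);
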
} // namespace steel
} // namespace mlx