MLX
Loading...
Searching...
No Matches
mma.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
5#include <metal_simdgroup>
6#include <metal_simdgroup_matrix>
7#include <metal_stdlib>
8
12
13using namespace metal;
14
16// MMA helper
18
19namespace mlx {
20namespace steel {
21
22template <typename T, int kFragRows_, int kFragCols_>
24 static_assert(
25 kFragRows_ == 8,
26 "Only 8 x 8 fragment matrices are currently supported");
27 static_assert(
28 kFragCols_ == 8,
29 "Only 8 x 8 fragment matrices are currently supported");
30};
31
32template <typename T>
33struct BaseMMAFrag<T, 8, 8> {
34 STEEL_CONST int kFragRows = 8;
35 STEEL_CONST int kFragCols = 8;
36
37 STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
38
39 STEEL_CONST int kElemRows = 1;
40 STEEL_CONST int kElemCols = 2;
41
42 static_assert(
43 kElemRows * kElemCols == kElemsPerFrag,
44 "MMAFrag shape is not consistent with MMAFrag size");
45
46 typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
47 typedef metal::vec<T, kElemsPerFrag> frag_type;
48
/**
 * Map a simdgroup lane id to that lane's coordinate inside the 8x8 fragment.
 *
 * The result packs the fn term in .x and the fm term in .y; BlockMMA reads .y
 * as its M offset and .x as its N offset.
 */
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
                                                 [[thread_index_in_simdgroup]]) {
  // Four consecutive lanes share one quad id.
  const short quad = simd_lane_id / 4;
  const short row = ((simd_lane_id / 2) % 4) + (quad & 4);
  const short col = 2 * (simd_lane_id % 2) + 2 * (quad & 2);
  return short2{col, row};
}
56
/**
 * Fill a fragment from strided memory.
 *
 * Fragment element (i, j) is read from src[i * str_x + j * str_y] and
 * converted to T with static_cast.
 */
template <typename SrcPtrType, typename StrX, typename StrY>
METAL_FUNC static constexpr void
load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
  STEEL_PRAGMA_UNROLL
  for (short row = 0; row < kElemRows; ++row) {
    STEEL_PRAGMA_UNROLL
    for (short col = 0; col < kElemCols; ++col) {
      const int slot = row * kElemCols + col;
      dst[slot] = static_cast<T>(src[row * str_x + col * str_y]);
    }
  }
}
68
/**
 * Bounds-checked variant of load().
 *
 * Fragment element (i, j) is read from
 * src[(off_x + i) * str_x + (off_y + j) * str_y] when (off_x + i) < lim_x and
 * (off_y + j) < lim_y; otherwise the element is set to T(0).
 *
 * @param dst    Fragment to fill.
 * @param src    Source, indexed with operator[].
 * @param str_x  Stride multiplied with the first index (off_x + i).
 * @param str_y  Stride multiplied with the second index (off_y + j).
 * @param lim_x  Exclusive upper bound for off_x + i.
 * @param lim_y  Exclusive upper bound for off_y + j.
 * @param off_x  Offset added to i (defaults to Int<0>).
 * @param off_y  Offset added to j (defaults to Int<0>).
 */
template <
    typename SrcPtrType,
    typename StrX,
    typename StrY,
    typename LimX,
    typename LimY,
    typename OffX,
    typename OffY>
METAL_FUNC static constexpr void load_safe(
    thread frag_type& dst,
    SrcPtrType src,
    StrX str_x,
    StrY str_y,
    LimX lim_x,
    LimY lim_y,
    OffX off_x = Int<0>{},
    OffY off_y = Int<0>{}) {
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < kElemRows; i++) {
    STEEL_PRAGMA_UNROLL
    for (short j = 0; j < kElemCols; j++) {
      if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
        // The second index is offset by off_y (it previously used off_x), so
        // the address agrees with the bounds test above and with store_safe().
        dst[i * kElemCols + j] =
            static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
      } else {
        dst[i * kElemCols + j] = T(0);
      }
    }
  }
}
99
// Write a fragment to strided memory: element (i, j) of src is cast to U and
// stored at dst[i * str_x + j * str_y].
// NOTE(review): the declaration of U is not visible in this listing —
// presumably the element type that DstPtrType points to; confirm.
100 template <typename DstPtrType, typename StrX, typename StrY>
101 METAL_FUNC static constexpr void
102 store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
104
106 for (short i = 0; i < kElemRows; i++) {
108 for (short j = 0; j < kElemCols; j++) {
109 dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
110 }
111 }
112 }
113
// Bounds-checked variant of store(): element (i, j) of src is cast to U and
// written to dst[(off_x + i) * str_x + (off_y + j) * str_y] only when
// (off_x + i) < lim_x and (off_y + j) < lim_y; other elements are not written.
// off_x and off_y default to Int<0>.
// NOTE(review): the declaration of U is not visible in this listing —
// presumably the element type that DstPtrType points to; confirm.
114 template <
115 typename DstPtrType,
116 typename StrX,
117 typename StrY,
118 typename LimX,
119 typename LimY,
120 typename OffX,
121 typename OffY>
122 METAL_FUNC static constexpr void store_safe(
123 const thread frag_type& src,
124 DstPtrType dst,
125 StrX str_x,
126 StrY str_y,
127 LimX lim_x,
128 LimY lim_y,
129 OffX off_x = Int<0>{},
130 OffY off_y = Int<0>{}) {
132
134 for (short i = 0; i < kElemRows; i++) {
136 for (short j = 0; j < kElemCols; j++) {
// Skip elements whose offset position is outside the limits.
137 if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
138 dst[(off_x + i) * str_x + (off_y + j) * str_y] =
139 static_cast<U>(src[i * kElemCols + j]);
140 }
141 }
142 }
143 }
144
/**
 * Fragment-level multiply-accumulate.
 *
 * A, B and C are copied into simdgroup matrices through thread_elements(),
 * the matrix overload of mma() is invoked, and the thread elements of the
 * result matrix are copied into D.
 */
METAL_FUNC static constexpr void mma(
    thread frag_type& D,
    thread frag_type& A,
    thread frag_type& B,
    thread frag_type& C) {
  mat_type lhs;
  mat_type rhs;
  mat_type addend;
  mat_type result;

  reinterpret_cast<thread frag_type&>(lhs.thread_elements()) = A;
  reinterpret_cast<thread frag_type&>(rhs.thread_elements()) = B;
  reinterpret_cast<thread frag_type&>(addend.thread_elements()) = C;

  mma(result, lhs, rhs, addend);

  D = reinterpret_cast<thread frag_type&>(result.thread_elements());
}

/**
 * Matrix-level multiply-accumulate: forwards all four operands, in order, to
 * simdgroup_multiply_accumulate.
 */
METAL_FUNC static constexpr void mma(
    thread mat_type& D,
    thread mat_type& A,
    thread mat_type& B,
    thread mat_type& C) {
  simdgroup_multiply_accumulate(D, A, B, C);
}
171};
172
173template <
174 typename T,
175 int kTileRows_,
176 int kTileCols_,
177 class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
178struct MMATile {
179 using MMAFrag_t = MMAFrag_;
180 using elem_type = T;
181 STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
182 STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
183 STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
184
185 STEEL_CONST int kTileRows = kTileRows_;
186 STEEL_CONST int kTileCols = kTileCols_;
187
190
193
194 typedef typename MMAFrag_t::mat_type mat_type;
195 typedef typename MMAFrag_t::frag_type frag_type;
196
198
// Constructor for tiles in the thread address space. The body is empty, so it
// performs no initialization itself; clear() sets every fragment to
// frag_type(0).
199 METAL_FUNC MMATile() thread {}
200
/** Reset the tile: every entry of val_frags becomes frag_type(0). */
METAL_FUNC constexpr void clear() {
  STEEL_PRAGMA_UNROLL
  for (short f = 0; f < kNumFrags; ++f) {
    val_frags[f] = frag_type(0);
  }
}
207
/**
 * Mutable reference to the fragment at tile position (i, j).
 * Fragments are addressed as val_frags[i * kTileCols + j].
 */
METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
  const int flat = i * kTileCols + j;
  return val_frags[flat];
}

/** Read-only counterpart of frag_at(i, j) for const tiles. */
METAL_FUNC constexpr const thread frag_type& frag_at(
    const short i,
    const short j) const {
  const int flat = i * kTileCols + j;
  return val_frags[flat];
}
217
/**
 * Build a simdgroup matrix from the fragment at (i, j) by copying the
 * fragment's kElemsPerFrag values into the matrix's thread elements.
 */
METAL_FUNC mat_type mat_at(const short i, const short j) {
  mat_type out;
  thread frag_type& frag = frag_at(i, j);
  STEEL_PRAGMA_UNROLL
  for (short e = 0; e < kElemsPerFrag; ++e) {
    out.thread_elements()[e] = frag[e];
  }
  return out;
}
226
// View the fragment storage (val_frags) as a flat array of elem_type.
227 METAL_FUNC thread elem_type* elems() {
228 return reinterpret_cast<thread elem_type*>(val_frags);
229 }
230
// Read-only flat view of the fragment storage for const tiles.
231 METAL_FUNC const thread elem_type* elems() const {
232 return reinterpret_cast<const thread elem_type*>(val_frags);
233 }
234
/**
 * Load every fragment of the tile from threadgroup memory.
 *
 * Fragment (i, j) starts at
 * src[(i * kFragRows) * w_x * str_x + (j * kFragCols) * w_y * str_y] and is
 * read by MMAFrag_t::load with the strides Int<str_x> and Int<str_y>.
 */
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void load(const threadgroup U* src) {
  STEEL_PRAGMA_UNROLL
  for (short ti = 0; ti < kTileRows; ++ti) {
    STEEL_PRAGMA_UNROLL
    for (short tj = 0; tj < kTileCols; ++tj) {
      const threadgroup U* frag_src =
          &(src[(ti * kFragRows) * w_x * str_x +
                (tj * kFragCols) * w_y * str_y]);
      MMAFrag_t::load(frag_at(ti, tj), frag_src, Int<str_x>{}, Int<str_y>{});
    }
  }
}
251
/**
 * Store every fragment of the tile to threadgroup memory.
 *
 * Fragment (i, j) is written by MMAFrag_t::store starting at
 * dst[(i * kFragRows) * w_x * str_x + (j * kFragCols) * w_y * str_y] with the
 * strides Int<str_x> and Int<str_y>.
 */
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void store(threadgroup U* dst) const {
  STEEL_PRAGMA_UNROLL
  for (short ti = 0; ti < kTileRows; ++ti) {
    STEEL_PRAGMA_UNROLL
    for (short tj = 0; tj < kTileCols; ++tj) {
      threadgroup U* frag_dst =
          &(dst[(ti * kFragRows) * w_x * str_x +
                (tj * kFragCols) * w_y * str_y]);
      MMAFrag_t::store(frag_at(ti, tj), frag_dst, Int<str_x>{}, Int<str_y>{});
    }
  }
}
268
/**
 * Load every fragment of the tile from device memory.
 *
 * Fragment (i, j) starts at
 * src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y] and is read by
 * MMAFrag_t::load with strides ld and Int<1>.
 */
template <typename U, int w_x, int w_y>
METAL_FUNC void load(const device U* src, const int ld) {
  STEEL_PRAGMA_UNROLL
  for (short ti = 0; ti < kTileRows; ++ti) {
    STEEL_PRAGMA_UNROLL
    for (short tj = 0; tj < kTileCols; ++tj) {
      const device U* frag_src =
          &(src[(ti * kFragRows) * w_x * ld + (tj * kFragCols) * w_y]);
      MMAFrag_t::load(frag_at(ti, tj), frag_src, ld, Int<1>{});
    }
  }
}
283
/**
 * Store every fragment of the tile to device memory.
 *
 * Fragment (i, j) is written by MMAFrag_t::store starting at
 * dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y] with strides ld and
 * Int<1>.
 */
template <typename U, int w_x, int w_y>
METAL_FUNC void store(device U* dst, const int ld) const {
  STEEL_PRAGMA_UNROLL
  for (short ti = 0; ti < kTileRows; ++ti) {
    STEEL_PRAGMA_UNROLL
    for (short tj = 0; tj < kTileCols; ++tj) {
      device U* frag_dst =
          &(dst[(ti * kFragRows) * w_x * ld + (tj * kFragCols) * w_y]);
      MMAFrag_t::store(frag_at(ti, tj), frag_dst, ld, Int<1>{});
    }
  }
}
298
/**
 * Bounds-checked load of the whole tile from device memory.
 *
 * Each fragment is read with MMAFrag_t::load_safe using strides (ld, Int<1>),
 * limits (src_tile_dims.y, src_tile_dims.x) and per-fragment offsets
 * ((i * kFragRows) * w_x, (j * kFragCols) * w_y). The base pointer src is
 * passed unchanged; the offsets select the fragment.
 */
template <typename U, int w_x, int w_y>
METAL_FUNC void
load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
  STEEL_PRAGMA_UNROLL
  for (int ti = 0; ti < kTileRows; ++ti) {
    const int row_off = (ti * kFragRows) * w_x;
    STEEL_PRAGMA_UNROLL
    for (int tj = 0; tj < kTileCols; ++tj) {
      const int col_off = (tj * kFragCols) * w_y;
      MMAFrag_t::load_safe(
          frag_at(ti, tj),
          src,
          ld,
          Int<1>{},
          src_tile_dims.y,
          src_tile_dims.x,
          row_off,
          col_off);
    }
  }
}
318
/**
 * Bounds-checked store of the whole tile to device memory.
 *
 * Each fragment is written with MMAFrag_t::store_safe using strides
 * (ld, Int<1>), limits (dst_tile_dims.y, dst_tile_dims.x) and per-fragment
 * offsets ((i * kFragRows) * w_x, (j * kFragCols) * w_y). The base pointer dst
 * is passed unchanged; the offsets select the fragment.
 */
template <typename U, int w_x, int w_y>
METAL_FUNC void
store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
  STEEL_PRAGMA_UNROLL
  for (int ti = 0; ti < kTileRows; ++ti) {
    const int row_off = (ti * kFragRows) * w_x;
    STEEL_PRAGMA_UNROLL
    for (int tj = 0; tj < kTileCols; ++tj) {
      const int col_off = (tj * kFragCols) * w_y;
      MMAFrag_t::store_safe(
          frag_at(ti, tj),
          dst,
          ld,
          Int<1>{},
          dst_tile_dims.y,
          dst_tile_dims.x,
          row_off,
          col_off);
    }
  }
}
338};
339
// Tile-level multiply-accumulate over fragments: for every (m, n) position and
// every k, a call receives D(m, n), A(m, k), B(k, n) and C(m, n).
// NOTE(review): the head of that innermost call (embedded line 353) is not
// visible in this listing — presumably the fragment mma(); confirm.
340template <typename T, typename U, int M, int N, int K>
341METAL_FUNC void tile_matmad(
342 thread MMATile<T, M, N>& D,
343 thread MMATile<U, M, K>& A,
344 thread MMATile<U, K, N>& B,
345 thread MMATile<T, M, N>& C) {
347 for (short m = 0; m < M; ++m) {
349 for (short n = 0; n < N; ++n) {
// Odd rows walk the columns in reverse order (N - 1 - n); even rows walk
// them forward.
350 short n_serp = (m % 2) ? (N - 1 - n) : n;
352 for (short k = 0; k < K; ++k) {
354 D.frag_at(m, n_serp),
355 A.frag_at(m, k),
356 B.frag_at(k, n_serp),
357 C.frag_at(m, n_serp));
358 }
359 }
360 }
361}
362
363template <
364 typename T,
365 typename U,
366 int BM,
367 int BN,
368 int BK,
369 int WM,
370 int WN,
371 bool transpose_a,
372 bool transpose_b,
373 short lda_tgp,
374 short ldb_tgp,
375 typename AccumType = float,
376 typename Epilogue = TransformNone<U, AccumType>>
377struct BlockMMA {
378 // MMAFrag size
381
382 // Warp tile simdgroup matrix strides along M
384 // Warp tile simdgroup matrix strides along N
386
387 // Warp tile size along M
389 // Warp tile size along N
391
392 // Threadgroup A strides
393 STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
394 STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
395
396 // Threadgroup B strides
397 STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
398 STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
399
400 // Threadgroup strides along K
403
404 // Simdgroup matrices
408
409 // Offsets within threadgroup
410 short sm;
411 short sn;
412
415
/* Constructor: derive this thread's offsets from its simdgroup and lane ids */
METAL_FUNC BlockMMA(
    ushort simd_group_id [[simdgroup_index_in_threadgroup]],
    ushort simd_lane_id [[thread_index_in_simdgroup]]) {
  // Simdgroup origin, in units of kFragSize, on a grid that is WN wide
  const short tm = kFragSize * (simd_group_id / WN);
  const short tn = kFragSize * (simd_group_id % WN);

  // Lane coordinate inside a fragment (.y is used for M, .x for N)
  const short2 lane_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
  const short lane_m = lane_coord.y;
  const short lane_n = lane_coord.x;

  // Starting offsets into the threadgroup copies of A and B
  As_offset = (tm + lane_m) * A_str_m + (lane_n)*A_str_k; // M, K
  Bs_offset = (lane_m)*B_str_k + (tn + lane_n) * B_str_n; // K, N

  // Offsets within the threadgroup tile, used when addressing C and D
  sm = lane_m + tm;
  sn = lane_n + tn;
}
435
436 /* (BM, BK) X (BK, BN) multiply accumulate function */
// As and Bs are threadgroup pointers; both are advanced locally by the
// per-thread offsets computed in the constructor.
437 METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
438 // Adjust for simdgroup and thread location
439 As += As_offset;
440 Bs += Bs_offset;
441
442 // Iterate over BK in blocks of kFragSize
444 for (short kk = 0; kk < BK; kk += kFragSize) {
445 simdgroup_barrier(mem_flags::mem_none);
446
// Load the A tile for this K step.
447 Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
448
449 simdgroup_barrier(mem_flags::mem_none);
450
// Load the B tile for this K step.
451 Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
452
453 simdgroup_barrier(mem_flags::mem_none);
454
// NOTE(review): the statement that consumes Atile and Btile (embedded line
// 455) is not visible in this listing — presumably a tile_matmad call that
// accumulates into Ctile; confirm.
456
457 // Progress to next simdgroup tile
458 As += tile_stride_a;
459 Bs += tile_stride_b;
460 }
461 }
462
/* Apply the Epilogue to the accumulators, then write the tile to device memory */
METAL_FUNC void store_result(device U* D, const int ldd) {
  constexpr int kCount = decltype(Ctile)::kElemsPerTile;

  // Transform every accumulator element in place
  STEEL_PRAGMA_UNROLL
  for (short e = 0; e < kCount; ++e) {
    Ctile.elems()[e] = Epilogue::apply(Ctile.elems()[e]);
  }

  // Shift D to this thread's position before storing
  Ctile.template store<U, WM, WN>(D + (sm * ldd + sn), ldd);
}
476
/*
 * Bounds-checked store_result: applies the Epilogue to the accumulators, then
 * writes only the elements that fall inside dst_tile_dims (x limits columns,
 * y limits rows, both measured from the threadgroup tile origin).
 */
METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
  constexpr int kCount = decltype(Ctile)::kElemsPerTile;

  // Transform every accumulator element in place
  STEEL_PRAGMA_UNROLL
  for (short e = 0; e < kCount; ++e) {
    Ctile.elems()[e] = Epilogue::apply(Ctile.elems()[e]);
  }

  // Re-express the limits relative to this thread's position
  dst_tile_dims -= short2(sn, sm);
  D += sm * ldd + sn;

  // Nothing to write when this thread starts outside the valid region
  const bool has_work = dst_tile_dims.x > 0 && dst_tile_dims.y > 0;
  if (!has_work) {
    return;
  }

  Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
}
494
/* Apply a unary epilogue to every accumulator element held by this thread */
template <typename UnaryEpilogue>
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
  constexpr int kCount = decltype(Ctile)::kElemsPerTile;

  STEEL_PRAGMA_UNROLL
  for (short e = 0; e < kCount; ++e) {
    Ctile.elems()[e] = epilogue_op.apply(Ctile.elems()[e]);
  }
}
504
/*
 * Apply a binary epilogue: each accumulator element is combined with the
 * matching element of C, addressed with row stride ldc and column stride fdc.
 */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue(
    const device U* C,
    const int ldc,
    const int fdc,
    thread const BinaryEpilogue& epilogue_op) {
  // Move C to this thread's position
  C += (sm)*ldc + (sn)*fdc;

  constexpr int kPerFrag = decltype(Ctile)::kElemsPerFrag;

  // Visit every fragment of the accumulator tile
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < TM; i++) {
    STEEL_PRAGMA_UNROLL
    for (short j = 0; j < TN; j++) {
      thread auto& accum = Ctile.frag_at(i, j);
      const device U* c_frag =
          C + ((i * TM_stride) * ldc + (j * TN_stride) * fdc);

      // Consecutive fragment elements are fdc apart in C
      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < kPerFrag; k++) {
        accum[k] = epilogue_op.apply(accum[k], c_frag[k * fdc]);
      }
    }
  }
}
532
/*
 * Bounds-checked binary epilogue.
 *
 * Each accumulator element is combined with the matching element of C
 * (row stride ldc, column stride fdc). Elements of C outside dst_tile_dims
 * (x limits columns, y limits rows, measured from the threadgroup tile
 * origin) are not read; a zero-initialized value is passed to the epilogue
 * in their place.
 */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue_safe(
    const device U* C,
    const int ldc,
    const int fdc,
    short2 dst_tile_dims,
    thread const BinaryEpilogue& epilogue_op) {
  // Adjust for simdgroup and thread location
  C += (sm)*ldc + (sn)*fdc;
  dst_tile_dims -= short2(sn, sm);

  if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
    return;

  constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

  // Loop over all simdgroup tiles
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < TM; i++) {
    // Rows at or past dst_tile_dims.y are outside the valid region; reading C
    // there would be out of bounds. This is the same row test that
    // store_result_safe applies before touching C and D.
    const bool row_in_bounds = (i * TM_stride) < dst_tile_dims.y;

    STEEL_PRAGMA_UNROLL
    for (short j = 0; j < TN; j++) {
      // Get accumulated result and associated offset in C
      thread auto& accum = Ctile.frag_at(i, j);
      int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

      // Read C
      U c_elems[kelems] = {0};

      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < kelems; k++) {
        if (row_in_bounds && (j * TN_stride + k) < dst_tile_dims.x) {
          c_elems[k] = C[offset_c + k * fdc];
        }
      }

      // Apply epilogue
      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < kelems; k++) {
        accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
      }
    }
  }
}
577
/*
 * Combine the accumulators with C through epilogue_op and write the result
 * to D. C is addressed with strides (ldc, fdc); D with row stride ldd and
 * unit column stride. The accumulators themselves are not modified.
 */
METAL_FUNC void store_result(
    device U* D,
    const int ldd,
    const device U* C,
    const int ldc,
    const int fdc,
    thread const Epilogue& epilogue_op) const {
  // Move both pointers to this thread's position
  C += (sm)*ldc + (sn)*fdc;
  D += (sm)*ldd + sn;

  constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

  // Visit every fragment of the accumulator tile
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < TM; i++) {
    STEEL_PRAGMA_UNROLL
    for (short j = 0; j < TN; j++) {
      thread const auto& accum = Ctile.frag_at(i, j);
      const device U* c_frag =
          C + ((i * TM_stride) * ldc + (j * TN_stride) * fdc);
      device U* d_frag = D + ((i * TM_stride) * ldd + (j * TN_stride));

      // Combine and write each element of the fragment
      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < kelems; k++) {
        d_frag[k] = epilogue_op.apply(accum[k], c_frag[k * fdc]);
      }
    }
  }
}
610
/*
 * Bounds-checked version of store_result with a C operand: only elements
 * inside dst_tile_dims (x limits columns, y limits rows, measured from the
 * threadgroup tile origin) are read from C and written to D.
 */
METAL_FUNC void store_result_safe(
    device U* D,
    const int ldd,
    const device U* C,
    const int ldc,
    const int fdc,
    short2 dst_tile_dims,
    thread const Epilogue& epilogue_op) const {
  // Move both pointers and the limits to this thread's position
  C += (sm)*ldc + (sn)*fdc;
  D += (sm)*ldd + sn;
  dst_tile_dims -= short2(sn, sm);

  const bool has_work = dst_tile_dims.x > 0 && dst_tile_dims.y > 0;
  if (!has_work) {
    return;
  }

  constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

  STEEL_PRAGMA_UNROLL
  for (int i = 0; i < TM; i++) {
    // Skip fragment rows that start at or past the row limit
    if (!(i * TM_stride < dst_tile_dims.y)) {
      continue;
    }

    STEEL_PRAGMA_UNROLL
    for (int j = 0; j < TN; j++) {
      thread const auto& accum = Ctile.frag_at(i, j);
      const device U* c_frag =
          C + ((i * TM_stride) * ldc + (j * TN_stride) * fdc);
      device U* d_frag = D + ((i * TM_stride) * ldd + (j * TN_stride));

      // Combine and write only the in-range columns
      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < kelems; k++) {
        if ((j * TN_stride + k) < dst_tile_dims.x) {
          d_frag[k] = epilogue_op.apply(accum[k], c_frag[k * fdc]);
        }
      }
    }
  }
}
651};
652
653} // namespace steel
654} // namespace mlx
Definition bf16.h:265
typename pointer_element< remove_cv_t< T > >::type pointer_element_t
Definition type_traits.h:51
METAL_FUNC void tile_matmad(thread MMATile< T, M, N > &D, thread MMATile< U, M, K > &A, thread MMATile< U, K, N > &B, thread MMATile< T, M, N > &C)
Definition mma.h:341
Definition allocator.h:7
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
static METAL_FUNC constexpr void mma(thread mat_type &D, thread mat_type &A, thread mat_type &B, thread mat_type &C)
Definition mma.h:164
static METAL_FUNC constexpr void store_safe(const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{})
Definition mma.h:122
metal::simdgroup_matrix< T, kFragRows, kFragCols > mat_type
Definition mma.h:46
static METAL_FUNC constexpr short2 get_coord(ushort simd_lane_id)
Definition mma.h:49
static METAL_FUNC constexpr void mma(thread frag_type &D, thread frag_type &A, thread frag_type &B, thread frag_type &C)
Definition mma.h:145
static METAL_FUNC constexpr void store(const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y)
Definition mma.h:102
static METAL_FUNC constexpr void load(thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y)
Definition mma.h:59
static METAL_FUNC constexpr void load_safe(thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{})
Definition mma.h:77
metal::vec< T, kElemsPerFrag > frag_type
Definition mma.h:47
Definition mma.h:23
Definition mma.h:377
METAL_FUNC void store_result(device U *D, const int ldd)
Definition mma.h:464
METAL_FUNC void store_result_safe(device U *D, const int ldd, short2 dst_tile_dims)
Definition mma.h:478
short As_offset
Definition mma.h:413
MMATile< AccumType, 1, TN, MMAFrag_acc_t > Btile
Definition mma.h:406
STEEL_CONST short A_str_k
Definition mma.h:394
STEEL_CONST short B_str_n
Definition mma.h:398
STEEL_CONST short TM_stride
Definition mma.h:383
METAL_FUNC void mma(const threadgroup T *As, const threadgroup T *Bs)
Definition mma.h:437
STEEL_CONST short TN
Definition mma.h:390
METAL_FUNC void store_result_safe(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const Epilogue &epilogue_op) const
Definition mma.h:611
METAL_FUNC void store_result(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, thread const Epilogue &epilogue_op) const
Definition mma.h:579
MMATile< AccumType, TM, TN, MMAFrag_acc_t > Ctile
Definition mma.h:407
METAL_FUNC void apply_epilogue(const device U *C, const int ldc, const int fdc, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:507
STEEL_CONST short TN_stride
Definition mma.h:385
STEEL_CONST short tile_stride_a
Definition mma.h:401
short Bs_offset
Definition mma.h:414
METAL_FUNC void apply_epilogue_safe(const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:535
METAL_FUNC BlockMMA(ushort simd_group_id, ushort simd_lane_id)
Definition mma.h:417
STEEL_CONST short B_str_k
Definition mma.h:397
short sm
Definition mma.h:410
STEEL_CONST short A_str_m
Definition mma.h:393
STEEL_CONST short TM
Definition mma.h:388
short sn
Definition mma.h:411
STEEL_CONST short tile_stride_b
Definition mma.h:402
STEEL_CONST short kFragSize
Definition mma.h:379
MMATile< AccumType, TM, 1, MMAFrag_acc_t > Atile
Definition mma.h:405
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue &epilogue_op)
Definition mma.h:497
Definition mma.h:178
METAL_FUNC constexpr thread frag_type & frag_at(const short i, const short j)
Definition mma.h:208
STEEL_CONST int kTileRows
Definition mma.h:185
MMAFrag_t::mat_type mat_type
Definition mma.h:194
METAL_FUNC void store(threadgroup U *dst) const
Definition mma.h:253
METAL_FUNC mat_type mat_at(const short i, const short j)
Definition mma.h:218
STEEL_CONST int kTileCols
Definition mma.h:186
METAL_FUNC void store_safe(device U *dst, const int ld, const short2 dst_tile_dims) const
Definition mma.h:321
STEEL_CONST int kFragRows
Definition mma.h:181
MMAFrag_t::frag_type frag_type
Definition mma.h:195
STEEL_CONST int kRows
Definition mma.h:188
METAL_FUNC void store(device U *dst, const int ld) const
Definition mma.h:285
T elem_type
Definition mma.h:180
METAL_FUNC thread elem_type * elems()
Definition mma.h:227
STEEL_CONST int kCols
Definition mma.h:189
STEEL_CONST int kElemsPerTile
Definition mma.h:192
METAL_FUNC void load_safe(const device U *src, const int ld, const short2 src_tile_dims)
Definition mma.h:301
METAL_FUNC MMATile() thread
Definition mma.h:199
METAL_FUNC void load(const threadgroup U *src)
Definition mma.h:236
METAL_FUNC constexpr void clear()
Definition mma.h:201
METAL_FUNC void load(const device U *src, const int ld)
Definition mma.h:270
MMAFrag_ MMAFrag_t
Definition mma.h:179
frag_type val_frags[kNumFrags]
Definition mma.h:197
STEEL_CONST int kFragCols
Definition mma.h:182
METAL_FUNC constexpr const thread frag_type & frag_at(const short i, const short j) const
Definition mma.h:212
METAL_FUNC const thread elem_type * elems() const
Definition mma.h:231
STEEL_CONST int kNumFrags
Definition mma.h:191
STEEL_CONST int kElemsPerFrag
Definition mma.h:183
Definition integral_constant.h:18