5#include <metal_simdgroup>
6#include <metal_simdgroup_matrix>
22template <
typename T,
int kFragRows_,
int kFragCols_>
26 "Only 8 x 8 fragment matrices are currently supported");
29 "Only 8 x 8 fragment matrices are currently supported");
43 kElemRows * kElemCols == kElemsPerFrag,
44 "MMAFrag shape is not consistent with MMAFrag size");
46 typedef metal::simdgroup_matrix<T, kFragRows, kFragCols>
mat_type;
// Derives a per-lane 2-D coordinate from the lane's index within its
// simdgroup and returns it packed as short2{fn, fm}.
// For lane ids 0-31, fm lands in [0, 7] and fn is one of {0, 2, 4, 6}.
// NOTE(review): fm / fn look like the row / column of the lane's first
// element inside the 8 x 8 fragment - confirm against the load/store callers.
49 METAL_FUNC
static constexpr short2
get_coord(ushort simd_lane_id
50 [[thread_index_in_simdgroup]]) {
// Index of the group of four consecutive lanes this lane belongs to.
51 const short qid = simd_lane_id / 4;
// Bit 2 of qid contributes 0 or 4; (lane / 2) % 4 contributes 0..3.
52 const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
// Bit 1 of qid contributes 0 or 4 (after the * 2); lane parity adds 0 or 2.
53 const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Note the packing order: fn goes in .x and fm in .y.
54 return short2{fn, fm};
57 template <
typename SrcPtrType,
typename StrX,
typename StrY>
58 METAL_FUNC
static constexpr void
61 for (
short i = 0; i < kElemRows; i++) {
63 for (
short j = 0; j < kElemCols; j++) {
64 dst[i * kElemCols + j] =
static_cast<T
>(src[i * str_x + j * str_y]);
87 for (
short i = 0; i < kElemRows; i++) {
89 for (
short j = 0; j < kElemCols; j++) {
90 if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
91 dst[i * kElemCols + j] =
92 static_cast<T
>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
94 dst[i * kElemCols + j] = T(0);
100 template <
typename DstPtrType,
typename StrX,
typename StrY>
101 METAL_FUNC
static constexpr void
106 for (
short i = 0; i < kElemRows; i++) {
108 for (
short j = 0; j < kElemCols; j++) {
109 dst[i * str_x + j * str_y] =
static_cast<U
>(src[i * kElemCols + j]);
134 for (
short i = 0; i < kElemRows; i++) {
136 for (
short j = 0; j < kElemCols; j++) {
137 if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
138 dst[(off_x + i) * str_x + (off_y + j) * str_y] =
139 static_cast<U
>(src[i * kElemCols + j]);
145 METAL_FUNC
static constexpr void mma(
155 reinterpret_cast<thread
frag_type&
>(A_mat.thread_elements()) = A;
156 reinterpret_cast<thread
frag_type&
>(B_mat.thread_elements()) = B;
157 reinterpret_cast<thread
frag_type&
>(C_mat.thread_elements()) = C;
159 mma(D_mat, A_mat, B_mat, C_mat);
161 D =
reinterpret_cast<thread
frag_type&
>(D_mat.thread_elements());
164 METAL_FUNC
static constexpr void mma(
169 simdgroup_multiply_accumulate(D, A, B, C);
177 class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
214 const short j)
const {
222 val_mat.thread_elements()[ii] =
frag_at(i, j)[ii];
235 template <
typename U,
int w_x,
int w_y,
int str_x,
int str_y>
236 METAL_FUNC
void load(
const threadgroup U* src) {
252 template <
typename U,
int w_x,
int w_y,
int str_x,
int str_y>
253 METAL_FUNC
void store(threadgroup U* dst)
const {
269 template <
typename U,
int w_x,
int w_y>
270 METAL_FUNC
void load(
const device U* src,
const int ld) {
284 template <
typename U,
int w_x,
int w_y>
285 METAL_FUNC
void store(device U* dst,
const int ld)
const {
299 template <
typename U,
int w_x,
int w_y>
301 load_safe(
const device U* src,
const int ld,
const short2 src_tile_dims) {
306 MMAFrag_t::load_safe(
319 template <
typename U,
int w_x,
int w_y>
321 store_safe(device U* dst,
const int ld,
const short2 dst_tile_dims)
const {
326 MMAFrag_t::store_safe(
340template <
typename T,
typename U,
int M,
int N,
int K>
347 for (
short m = 0; m < M; ++m) {
349 for (
short n = 0; n < N; ++n) {
350 short n_serp = (m % 2) ? (N - 1 - n) : n;
352 for (
short k = 0; k < K; ++k) {
354 D.frag_at(m, n_serp),
356 B.frag_at(k, n_serp),
357 C.frag_at(m, n_serp));
375 typename AccumType = float,
376 typename Epilogue = TransformNone<U, AccumType>>
418 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
419 ushort simd_lane_id [[thread_index_in_simdgroup]]) {
421 short tm =
kFragSize * (simd_group_id / WN);
422 short tn =
kFragSize * (simd_group_id % WN);
424 short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
437 METAL_FUNC
void mma(
const threadgroup T* As,
const threadgroup T* Bs) {
444 for (
short kk = 0; kk < BK; kk +=
kFragSize) {
445 simdgroup_barrier(mem_flags::mem_none);
447 Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
449 simdgroup_barrier(mem_flags::mem_none);
451 Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
453 simdgroup_barrier(mem_flags::mem_none);
467 for (
short i = 0; i <
decltype(
Ctile)::kElemsPerTile; i++) {
468 Ctile.elems()[i] = Epilogue::apply(
Ctile.elems()[i]);
474 Ctile.template store<U, WM, WN>(D, ldd);
481 for (
short i = 0; i <
decltype(
Ctile)::kElemsPerTile; i++) {
482 Ctile.elems()[i] = Epilogue::apply(
Ctile.elems()[i]);
487 dst_tile_dims -= short2(
sn,
sm);
489 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
492 Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
496 template <
typename UnaryEpilogue>
500 for (
short i = 0; i <
decltype(
Ctile)::kElemsPerTile; i++) {
501 Ctile.elems()[i] = epilogue_op.apply(
Ctile.elems()[i]);
506 template <
typename BinaryEpilogue>
511 thread
const BinaryEpilogue& epilogue_op) {
513 C += (
sm)*ldc + (
sn)*fdc;
517 for (
short i = 0; i <
TM; i++) {
519 for (
short j = 0; j <
TN; j++) {
521 thread
auto& accum =
Ctile.frag_at(i, j);
526 for (
short k = 0; k <
decltype(
Ctile)::kElemsPerFrag; k++) {
527 accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
534 template <
typename BinaryEpilogue>
539 short2 dst_tile_dims,
540 thread
const BinaryEpilogue& epilogue_op) {
542 C += (
sm)*ldc + (
sn)*fdc;
543 dst_tile_dims -= short2(
sn,
sm);
545 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
550 for (
short i = 0; i <
TM; i++) {
552 for (
short j = 0; j <
TN; j++) {
554 thread
auto& accum =
Ctile.frag_at(i, j);
557 constexpr short kelems =
decltype(
Ctile)::kElemsPerFrag;
560 U c_elems[kelems] = {0};
563 for (
short k = 0; k < kelems; k++) {
564 if ((j *
TN_stride + k) < dst_tile_dims.x) {
565 c_elems[k] = C[offset_c + k * fdc];
571 for (
short k = 0; k < kelems; k++) {
572 accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
585 thread
const Epilogue& epilogue_op)
const {
587 C += (
sm)*ldc + (
sn)*fdc;
590 constexpr short kelems =
decltype(
Ctile)::kElemsPerFrag;
594 for (
short i = 0; i <
TM; i++) {
596 for (
short j = 0; j <
TN; j++) {
598 thread
const auto& accum =
Ctile.frag_at(i, j);
604 for (
short k = 0; k < kelems; k++) {
605 D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
617 short2 dst_tile_dims,
618 thread
const Epilogue& epilogue_op)
const {
620 C += (
sm)*ldc + (
sn)*fdc;
622 dst_tile_dims -= short2(
sn,
sm);
624 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
627 constexpr short kelems =
decltype(
Ctile)::kElemsPerFrag;
630 for (
int i = 0; i <
TM; i++) {
633 for (
int j = 0; j <
TN; j++) {
635 thread
const auto& accum =
Ctile.frag_at(i, j);
641 for (
short k = 0; k < kelems; k++) {
642 if ((j *
TN_stride + k) < dst_tile_dims.x) {
644 epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
METAL_FUNC void tile_matmad(thread MMATile< T, M, N > &D, thread MMATile< U, M, K > &A, thread MMATile< U, K, N > &B, thread MMATile< T, M, N > &C)
Definition mma.h:341
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
static METAL_FUNC constexpr void mma(thread mat_type &D, thread mat_type &A, thread mat_type &B, thread mat_type &C)
Definition mma.h:164
static METAL_FUNC constexpr void store_safe(const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{})
Definition mma.h:122
metal::simdgroup_matrix< T, kFragRows, kFragCols > mat_type
Definition mma.h:46
static METAL_FUNC constexpr short2 get_coord(ushort simd_lane_id)
Definition mma.h:49
static METAL_FUNC constexpr void mma(thread frag_type &D, thread frag_type &A, thread frag_type &B, thread frag_type &C)
Definition mma.h:145
static METAL_FUNC constexpr void store(const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y)
Definition mma.h:102
static METAL_FUNC constexpr void load(thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y)
Definition mma.h:59
static METAL_FUNC constexpr void load_safe(thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{})
Definition mma.h:77
metal::vec< T, kElemsPerFrag > frag_type
Definition mma.h:47
METAL_FUNC void store_result(device U *D, const int ldd)
Definition mma.h:464
METAL_FUNC void store_result_safe(device U *D, const int ldd, short2 dst_tile_dims)
Definition mma.h:478
short As_offset
Definition mma.h:413
MMATile< AccumType, 1, TN, MMAFrag_acc_t > Btile
Definition mma.h:406
STEEL_CONST short A_str_k
Definition mma.h:394
STEEL_CONST short B_str_n
Definition mma.h:398
STEEL_CONST short TM_stride
Definition mma.h:383
METAL_FUNC void mma(const threadgroup T *As, const threadgroup T *Bs)
Definition mma.h:437
STEEL_CONST short TN
Definition mma.h:390
METAL_FUNC void store_result_safe(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const Epilogue &epilogue_op) const
Definition mma.h:611
METAL_FUNC void store_result(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, thread const Epilogue &epilogue_op) const
Definition mma.h:579
MMATile< AccumType, TM, TN, MMAFrag_acc_t > Ctile
Definition mma.h:407
METAL_FUNC void apply_epilogue(const device U *C, const int ldc, const int fdc, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:507
STEEL_CONST short TN_stride
Definition mma.h:385
STEEL_CONST short tile_stride_a
Definition mma.h:401
short Bs_offset
Definition mma.h:414
METAL_FUNC void apply_epilogue_safe(const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:535
METAL_FUNC BlockMMA(ushort simd_group_id, ushort simd_lane_id)
Definition mma.h:417
STEEL_CONST short B_str_k
Definition mma.h:397
short sm
Definition mma.h:410
STEEL_CONST short A_str_m
Definition mma.h:393
STEEL_CONST short TM
Definition mma.h:388
short sn
Definition mma.h:411
STEEL_CONST short tile_stride_b
Definition mma.h:402
STEEL_CONST short kFragSize
Definition mma.h:379
MMATile< AccumType, TM, 1, MMAFrag_acc_t > Atile
Definition mma.h:405
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue &epilogue_op)
Definition mma.h:497
METAL_FUNC constexpr thread frag_type & frag_at(const short i, const short j)
Definition mma.h:208
STEEL_CONST int kTileRows
Definition mma.h:185
MMAFrag_t::mat_type mat_type
Definition mma.h:194
METAL_FUNC void store(threadgroup U *dst) const
Definition mma.h:253
METAL_FUNC mat_type mat_at(const short i, const short j)
Definition mma.h:218
STEEL_CONST int kTileCols
Definition mma.h:186
METAL_FUNC void store_safe(device U *dst, const int ld, const short2 dst_tile_dims) const
Definition mma.h:321
STEEL_CONST int kFragRows
Definition mma.h:181
MMAFrag_t::frag_type frag_type
Definition mma.h:195
STEEL_CONST int kRows
Definition mma.h:188
METAL_FUNC void store(device U *dst, const int ld) const
Definition mma.h:285
T elem_type
Definition mma.h:180
METAL_FUNC thread elem_type * elems()
Definition mma.h:227
STEEL_CONST int kCols
Definition mma.h:189
STEEL_CONST int kElemsPerTile
Definition mma.h:192
METAL_FUNC void load_safe(const device U *src, const int ld, const short2 src_tile_dims)
Definition mma.h:301
METAL_FUNC MMATile() thread
Definition mma.h:199
METAL_FUNC void load(const threadgroup U *src)
Definition mma.h:236
METAL_FUNC constexpr void clear()
Definition mma.h:201
METAL_FUNC void load(const device U *src, const int ld)
Definition mma.h:270
MMAFrag_ MMAFrag_t
Definition mma.h:179
frag_type val_frags[kNumFrags]
Definition mma.h:197
STEEL_CONST int kFragCols
Definition mma.h:182
METAL_FUNC constexpr const thread frag_type & frag_at(const short i, const short j) const
Definition mma.h:212
METAL_FUNC const thread elem_type * elems() const
Definition mma.h:231
STEEL_CONST int kNumFrags
Definition mma.h:191
STEEL_CONST int kElemsPerFrag
Definition mma.h:183
Definition integral_constant.h:18