5#include <metal_simdgroup>
6#include <metal_simdgroup_matrix>
22template <
typename RInt,
typename CInt>
30template <
typename Shape,
typename Layout>
36template <
typename T,
int kFragRows_,
int kFragCols_>
40 "Only 8 x 8 fragment matrices are currently supported");
43 "Only 8 x 8 fragment matrices are currently supported");
47struct BaseMMAFrag<T, 8, 8> {
57 kElemRows * kElemCols == kElemsPerFrag,
58 "MMAFrag shape is not consistent with MMAFrag size");
60 typedef metal::simdgroup_matrix<T, kFragRows, kFragCols>
mat_type;
65 METAL_FUNC
static constexpr short2
get_coord(ushort simd_lane_id
66 [[thread_index_in_simdgroup]]) {
67 const short qid = simd_lane_id / 4;
68 const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
69 const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
70 return short2{fn, fm};
73 template <
typename SrcPtrType,
typename StrX,
typename StrY>
74 METAL_FUNC
static constexpr void
77 for (
short i = 0; i < kElemRows; i++) {
79 for (
short j = 0; j < kElemCols; j++) {
80 dst[i * kElemCols + j] =
static_cast<T
>(src[i * str_x + j * str_y]);
103 for (
short i = 0; i < kElemRows; i++) {
105 for (
short j = 0; j < kElemCols; j++) {
106 if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
107 dst[i * kElemCols + j] =
108 static_cast<T
>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
110 dst[i * kElemCols + j] = T(0);
116 template <
typename DstPtrType,
typename StrX,
typename StrY>
117 METAL_FUNC
static constexpr void
122 for (
short i = 0; i < kElemRows; i++) {
124 for (
short j = 0; j < kElemCols; j++) {
125 dst[i * str_x + j * str_y] =
static_cast<U
>(src[i * kElemCols + j]);
150 for (
short i = 0; i < kElemRows; i++) {
152 for (
short j = 0; j < kElemCols; j++) {
153 if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
154 dst[(off_x + i) * str_x + (off_y + j) * str_y] =
155 static_cast<U
>(src[i * kElemCols + j]);
161 METAL_FUNC
static constexpr void mma(
171 reinterpret_cast<thread
frag_type&
>(A_mat.thread_elements()) = A;
172 reinterpret_cast<thread
frag_type&
>(B_mat.thread_elements()) = B;
173 reinterpret_cast<thread
frag_type&
>(C_mat.thread_elements()) = C;
175 mma(D_mat, A_mat, B_mat, C_mat);
177 D =
reinterpret_cast<thread
frag_type&
>(D_mat.thread_elements());
180 METAL_FUNC
static constexpr void mma(
185 simdgroup_multiply_accumulate(D, A, B, C);
188 template <
typename Op>
191 thread T* reduced_vals) {
192 T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);
195 qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
198 sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
200 reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
203 template <
typename Op>
206 thread T* row_vals) {
208 for (
short i = 0; i < kElemRows; i++) {
210 for (
short j = 0; j < kElemCols; j++) {
211 inp_vals[i * kElemCols + j] =
212 Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
222 class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
262 const short j)
const {
270 val_mat.thread_elements()[ii] =
frag_at(i, j)[ii];
283 template <
typename Op>
290 frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
295 template <
typename Op>
302 frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
307 template <
typename U,
int w_x,
int w_y,
int str_x,
int str_y>
308 METAL_FUNC
void load(
const threadgroup U* src) {
324 template <
typename U,
int w_x,
int w_y,
int str_x,
int str_y>
325 METAL_FUNC
void store(threadgroup U* dst)
const {
341 template <
typename U,
int w_x,
int w_y>
342 METAL_FUNC
void load(
const device U* src,
const int ld) {
356 template <
typename U,
int w_x,
int w_y>
357 METAL_FUNC
void store(device U* dst,
const int ld)
const {
371 template <
typename U,
int w_x,
int w_y>
373 load_safe(
const device U* src,
const int ld,
const short2 src_tile_dims) {
378 MMAFrag_t::load_safe(
391 template <
typename U,
int w_x,
int w_y>
393 store_safe(device U* dst,
const int ld,
const short2 dst_tile_dims)
const {
398 MMAFrag_t::store_safe(
412template <
typename T,
typename U,
int M,
int N,
int K>
419 for (
short k = 0; k < K; ++k) {
421 for (
short m = 0; m < M; ++m) {
423 for (
short n = 0; n < N; ++n) {
424 short n_serp = (m % 2) ? (N - 1 - n) : n;
426 D.frag_at(m, n_serp),
428 B.frag_at(k, n_serp),
429 C.frag_at(m, n_serp));
447 typename AccumType = float,
448 typename Epilogue = TransformNone<U, AccumType>>
490 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
491 ushort simd_lane_id [[thread_index_in_simdgroup]]) {
493 short tm =
kFragSize * (simd_group_id / WN);
494 short tn =
kFragSize * (simd_group_id % WN);
496 short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
509 METAL_FUNC
void mma(
const threadgroup T* As,
const threadgroup T* Bs) {
516 for (
short kk = 0; kk < BK; kk +=
kFragSize) {
517 simdgroup_barrier(mem_flags::mem_none);
519 Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
521 simdgroup_barrier(mem_flags::mem_none);
523 Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
525 simdgroup_barrier(mem_flags::mem_none);
539 for (
short i = 0; i <
decltype(
Ctile)::kElemsPerTile; i++) {
540 Ctile.elems()[i] = Epilogue::apply(
Ctile.elems()[i]);
546 Ctile.template store<U, WM, WN>(D, ldd);
553 for (
short i = 0; i <
decltype(
Ctile)::kElemsPerTile; i++) {
554 Ctile.elems()[i] = Epilogue::apply(
Ctile.elems()[i]);
559 dst_tile_dims -= short2(
sn,
sm);
561 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
564 Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
568 template <
typename UnaryEpilogue>
572 for (
short i = 0; i <
decltype(
Ctile)::kElemsPerTile; i++) {
573 Ctile.elems()[i] = epilogue_op.apply(
Ctile.elems()[i]);
578 template <
typename BinaryEpilogue>
583 thread
const BinaryEpilogue& epilogue_op) {
585 C += (
sm)*ldc + (
sn)*fdc;
589 for (
short i = 0; i <
TM; i++) {
591 for (
short j = 0; j <
TN; j++) {
593 thread
auto& accum =
Ctile.frag_at(i, j);
598 for (
short k = 0; k <
decltype(
Ctile)::kElemsPerFrag; k++) {
599 accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
606 template <
typename BinaryEpilogue>
611 short2 dst_tile_dims,
612 thread
const BinaryEpilogue& epilogue_op) {
614 C += (
sm)*ldc + (
sn)*fdc;
615 dst_tile_dims -= short2(
sn,
sm);
617 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
622 for (
short i = 0; i <
TM; i++) {
624 for (
short j = 0; j <
TN; j++) {
626 thread
auto& accum =
Ctile.frag_at(i, j);
629 constexpr short kelems =
decltype(
Ctile)::kElemsPerFrag;
632 U c_elems[kelems] = {0};
635 for (
short k = 0; k < kelems; k++) {
636 if ((j *
TN_stride + k) < dst_tile_dims.x) {
637 c_elems[k] = C[offset_c + k * fdc];
643 for (
short k = 0; k < kelems; k++) {
644 accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
657 thread
const Epilogue& epilogue_op)
const {
659 C += (
sm)*ldc + (
sn)*fdc;
662 constexpr short kelems =
decltype(
Ctile)::kElemsPerFrag;
666 for (
short i = 0; i <
TM; i++) {
668 for (
short j = 0; j <
TN; j++) {
670 thread
const auto& accum =
Ctile.frag_at(i, j);
676 for (
short k = 0; k < kelems; k++) {
677 D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
689 short2 dst_tile_dims,
690 thread
const Epilogue& epilogue_op)
const {
692 C += (
sm)*ldc + (
sn)*fdc;
694 dst_tile_dims -= short2(
sn,
sm);
696 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
699 constexpr short kelems =
decltype(
Ctile)::kElemsPerFrag;
702 for (
int i = 0; i <
TM; i++) {
705 for (
int j = 0; j <
TN; j++) {
707 thread
const auto& accum =
Ctile.frag_at(i, j);
713 for (
short k = 0; k < kelems; k++) {
714 if ((j *
TN_stride + k) < dst_tile_dims.x) {
716 epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
METAL_FUNC void tile_matmad(thread MMATile< T, M, N > &D, thread MMATile< U, M, K > &A, thread MMATile< U, K, N > &B, thread MMATile< T, M, N > &C)
Definition mma.h:413
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
static METAL_FUNC constexpr void mma(thread mat_type &D, thread mat_type &A, thread mat_type &B, thread mat_type &C)
Definition mma.h:180
static METAL_FUNC constexpr void store_safe(const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{})
Definition mma.h:138
static METAL_FUNC constexpr void row_bin_op(thread frag_type &inp_vals, thread T *row_vals)
Definition mma.h:204
static METAL_FUNC constexpr void row_reduce(thread const frag_type &inp_vals, thread T *reduced_vals)
Definition mma.h:189
metal::vec< T, kElemRows > row_frag_type
Definition mma.h:62
static METAL_FUNC constexpr short2 get_coord(ushort simd_lane_id)
Definition mma.h:65
static METAL_FUNC constexpr void mma(thread frag_type &D, thread frag_type &A, thread frag_type &B, thread frag_type &C)
Definition mma.h:161
metal::simdgroup_matrix< T, kFragRows, kFragCols > mat_type
Definition mma.h:60
metal::vec< T, kElemsPerFrag > frag_type
Definition mma.h:61
static METAL_FUNC constexpr void store(const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y)
Definition mma.h:118
metal::vec< T, kElemCols > col_frag_type
Definition mma.h:63
static METAL_FUNC constexpr void load(thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y)
Definition mma.h:75
static METAL_FUNC constexpr void load_safe(thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{})
Definition mma.h:93
METAL_FUNC void store_result(device U *D, const int ldd)
Definition mma.h:536
METAL_FUNC void store_result_safe(device U *D, const int ldd, short2 dst_tile_dims)
Definition mma.h:550
short As_offset
Definition mma.h:485
MMATile< AccumType, TM, TN, MMAFrag_acc_t > Ctile
Definition mma.h:479
STEEL_CONST short A_str_k
Definition mma.h:466
MMATile< AccumType, 1, TN, MMAFrag_acc_t > Btile
Definition mma.h:478
MMATile< AccumType, TM, 1, MMAFrag_acc_t > Atile
Definition mma.h:477
STEEL_CONST short B_str_n
Definition mma.h:470
STEEL_CONST short TM_stride
Definition mma.h:455
METAL_FUNC void mma(const threadgroup T *As, const threadgroup T *Bs)
Definition mma.h:509
STEEL_CONST short TN
Definition mma.h:462
METAL_FUNC void store_result_safe(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const Epilogue &epilogue_op) const
Definition mma.h:683
METAL_FUNC void store_result(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, thread const Epilogue &epilogue_op) const
Definition mma.h:651
METAL_FUNC void apply_epilogue(const device U *C, const int ldc, const int fdc, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:579
STEEL_CONST short TN_stride
Definition mma.h:457
STEEL_CONST short tile_stride_a
Definition mma.h:473
short Bs_offset
Definition mma.h:486
METAL_FUNC void apply_epilogue_safe(const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:607
METAL_FUNC BlockMMA(ushort simd_group_id, ushort simd_lane_id)
Definition mma.h:489
STEEL_CONST short B_str_k
Definition mma.h:469
short sm
Definition mma.h:482
STEEL_CONST short A_str_m
Definition mma.h:465
STEEL_CONST short TM
Definition mma.h:460
short sn
Definition mma.h:483
STEEL_CONST short tile_stride_b
Definition mma.h:474
STEEL_CONST short kFragSize
Definition mma.h:451
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue &epilogue_op)
Definition mma.h:569
Shape shape
Definition mma.h:32
Layout layout
Definition mma.h:33
METAL_FUNC constexpr thread frag_type & frag_at(const short i, const short j)
Definition mma.h:256
STEEL_CONST int kTileRows
Definition mma.h:230
STEEL_CONST int kColsPerThread
Definition mma.h:240
MMAFrag_t::mat_type mat_type
Definition mma.h:242
METAL_FUNC void store(threadgroup U *dst) const
Definition mma.h:325
METAL_FUNC mat_type mat_at(const short i, const short j)
Definition mma.h:266
METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread])
Definition mma.h:296
STEEL_CONST int kTileCols
Definition mma.h:231
METAL_FUNC void store_safe(device U *dst, const int ld, const short2 dst_tile_dims) const
Definition mma.h:393
STEEL_CONST int kFragRows
Definition mma.h:226
STEEL_CONST int kRowsPerThread
Definition mma.h:239
STEEL_CONST int kRows
Definition mma.h:233
frag_type val_frags[kNumFrags]
Definition mma.h:245
MMAFrag_ MMAFrag_t
Definition mma.h:224
METAL_FUNC void store(device U *dst, const int ld) const
Definition mma.h:357
T elem_type
Definition mma.h:225
METAL_FUNC thread elem_type * elems()
Definition mma.h:275
STEEL_CONST int kCols
Definition mma.h:234
STEEL_CONST int kElemsPerTile
Definition mma.h:237
METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const
Definition mma.h:284
METAL_FUNC void load_safe(const device U *src, const int ld, const short2 src_tile_dims)
Definition mma.h:373
METAL_FUNC MMATile() thread
Definition mma.h:247
METAL_FUNC void load(const threadgroup U *src)
Definition mma.h:308
METAL_FUNC constexpr void clear()
Definition mma.h:249
METAL_FUNC void load(const device U *src, const int ld)
Definition mma.h:342
MMAFrag_t::frag_type frag_type
Definition mma.h:243
STEEL_CONST int kFragCols
Definition mma.h:227
METAL_FUNC constexpr const thread frag_type & frag_at(const short i, const short j) const
Definition mma.h:260
METAL_FUNC const thread elem_type * elems() const
Definition mma.h:279
STEEL_CONST int kNumFrags
Definition mma.h:236
STEEL_CONST int kElemsPerFrag
Definition mma.h:228
Shape2D(RInt r_, CInt c_)
Definition mma.h:27
RInt r
Definition mma.h:24
CInt c
Definition mma.h:25
Definition integral_constant.h:18