5#include <metal_simdgroup> 
    6#include <metal_simdgroup_matrix> 
   22template <
typename T, 
int kFragRows_, 
int kFragCols_>
 
   26      "Only 8 x 8 fragment matrices are currently supported");
 
   29      "Only 8 x 8 fragment matrices are currently supported");
 
 
   43      kElemRows * kElemCols == kElemsPerFrag,
 
   44      "MMAFrag shape is not consistent with MMAFrag size");
 
   46  typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> 
mat_type;
 
   49  METAL_FUNC 
static constexpr short2 
get_coord(ushort simd_lane_id
 
   50                                               [[thread_index_in_simdgroup]]) {
 
   51    const short qid = simd_lane_id / 4;
 
   52    const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
 
   53    const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
 
   54    return short2{fn, fm};
 
 
   57  template <
typename SrcPtrType, 
typename StrX, 
typename StrY>
 
   58  METAL_FUNC 
static constexpr void 
   61    for (
short i = 0; i < kElemRows; i++) {
 
   63      for (
short j = 0; j < kElemCols; j++) {
 
   64        dst[i * kElemCols + j] = 
static_cast<T
>(src[i * str_x + j * str_y]);
 
 
   87    for (
short i = 0; i < kElemRows; i++) {
 
   89      for (
short j = 0; j < kElemCols; j++) {
 
   90        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          // Row index pairs (off_x + i) with str_x and column index pairs
          // (off_y + j) with str_y, so the read address uses the same
          // coordinates the bounds check above just validated.
 
   91          dst[i * kElemCols + j] =
 
   92              static_cast<T
>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
 
   94          dst[i * kElemCols + j] = T(0);
 
 
  100  template <
typename DstPtrType, 
typename StrX, 
typename StrY>
 
  101  METAL_FUNC 
static constexpr void 
  106    for (
short i = 0; i < kElemRows; i++) {
 
  108      for (
short j = 0; j < kElemCols; j++) {
 
  109        dst[i * str_x + j * str_y] = 
static_cast<U
>(src[i * kElemCols + j]);
 
 
  134    for (
short i = 0; i < kElemRows; i++) {
 
  136      for (
short j = 0; j < kElemCols; j++) {
 
  137        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
 
  138          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
 
  139              static_cast<U
>(src[i * kElemCols + j]);
 
 
  145  METAL_FUNC 
static constexpr void mma(
 
  155    reinterpret_cast<thread 
frag_type&
>(A_mat.thread_elements()) = A;
 
  156    reinterpret_cast<thread 
frag_type&
>(B_mat.thread_elements()) = B;
 
  157    reinterpret_cast<thread 
frag_type&
>(C_mat.thread_elements()) = C;
 
  159    mma(D_mat, A_mat, B_mat, C_mat);
 
  161    D = 
reinterpret_cast<thread 
frag_type&
>(D_mat.thread_elements());
 
 
  164  METAL_FUNC 
static constexpr void mma(
 
  169    simdgroup_multiply_accumulate(D, A, B, C);
 
 
 
  177    class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
 
  214      const short j)
 const {
 
 
  222      val_mat.thread_elements()[ii] = 
frag_at(i, j)[ii];
 
 
  235  template <
typename U, 
int w_x, 
int w_y, 
int str_x, 
int str_y>
 
  236  METAL_FUNC 
void load(
const threadgroup U* src) {
 
 
  252  template <
typename U, 
int w_x, 
int w_y, 
int str_x, 
int str_y>
 
  253  METAL_FUNC 
void store(threadgroup U* dst)
 const {
 
 
  269  template <
typename U, 
int w_x, 
int w_y>
 
  270  METAL_FUNC 
void load(
const device U* src, 
const int ld) {
 
 
  284  template <
typename U, 
int w_x, 
int w_y>
 
  285  METAL_FUNC 
void store(device U* dst, 
const int ld)
 const {
 
 
  299  template <
typename U, 
int w_x, 
int w_y>
 
  301  load_safe(
const device U* src, 
const int ld, 
const short2 src_tile_dims) {
 
  306        MMAFrag_t::load_safe(
 
 
  319  template <
typename U, 
int w_x, 
int w_y>
 
  321  store_safe(device U* dst, 
const int ld, 
const short2 dst_tile_dims)
 const {
 
  326        MMAFrag_t::store_safe(
 
 
 
  340template <
typename T, 
typename U, 
int M, 
int N, 
int K>
 
  347  for (
short m = 0; m < M; ++m) {
 
  349    for (
short n = 0; n < N; ++n) {
 
  350      short n_serp = (m % 2) ? (N - 1 - n) : n;
 
  352      for (
short k = 0; k < K; ++k) {
 
  354            D.frag_at(m, n_serp),
 
  356            B.frag_at(k, n_serp),
 
  357            C.frag_at(m, n_serp));
 
 
  375    typename AccumType = float,
 
  376    typename Epilogue = TransformNone<U, AccumType>>
 
  418      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
 
  419      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
 
  421    short tm = 
kFragSize * (simd_group_id / WN);
 
  422    short tn = 
kFragSize * (simd_group_id % WN);
 
  424    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
 
 
  437  METAL_FUNC 
void mma(
const threadgroup T* As, 
const threadgroup T* Bs) {
 
  444    for (
short kk = 0; kk < BK; kk += 
kFragSize) {
 
  445      simdgroup_barrier(mem_flags::mem_none);
 
  447      Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
 
  449      simdgroup_barrier(mem_flags::mem_none);
 
  451      Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
 
  453      simdgroup_barrier(mem_flags::mem_none);
 
 
  467    for (
short i = 0; i < 
decltype(
Ctile)::kElemsPerTile; i++) {
 
  468      Ctile.elems()[i] = Epilogue::apply(
Ctile.elems()[i]);
 
  474    Ctile.template store<U, WM, WN>(D, ldd);
 
 
  481    for (
short i = 0; i < 
decltype(
Ctile)::kElemsPerTile; i++) {
 
  482      Ctile.elems()[i] = Epilogue::apply(
Ctile.elems()[i]);
 
  487    dst_tile_dims -= short2(
sn, 
sm);
 
  489    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
 
  492    Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
 
 
  496  template <
typename UnaryEpilogue>
 
  500    for (
short i = 0; i < 
decltype(
Ctile)::kElemsPerTile; i++) {
 
  501      Ctile.elems()[i] = epilogue_op.apply(
Ctile.elems()[i]);
 
 
  506  template <
typename BinaryEpilogue>
 
  511      thread 
const BinaryEpilogue& epilogue_op) {
 
  513    C += (
sm)*ldc + (
sn)*fdc;
 
  517    for (
short i = 0; i < 
TM; i++) {
 
  519      for (
short j = 0; j < 
TN; j++) {
 
  521        thread 
auto& accum = 
Ctile.frag_at(i, j);
 
  526        for (
short k = 0; k < 
decltype(
Ctile)::kElemsPerFrag; k++) {
 
  527          accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
 
 
  534  template <
typename BinaryEpilogue>
 
  539      short2 dst_tile_dims,
 
  540      thread 
const BinaryEpilogue& epilogue_op) {
 
  542    C += (
sm)*ldc + (
sn)*fdc;
 
  543    dst_tile_dims -= short2(
sn, 
sm);
 
  545    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
 
  550    for (
short i = 0; i < 
TM; i++) {
 
  552      for (
short j = 0; j < 
TN; j++) {
 
  554        thread 
auto& accum = 
Ctile.frag_at(i, j);
 
  557        constexpr short kelems = 
decltype(
Ctile)::kElemsPerFrag;
 
  560        U c_elems[kelems] = {0};
 
  563        for (
short k = 0; k < kelems; k++) {
 
  564          if ((j * 
TN_stride + k) < dst_tile_dims.x) {
 
  565            c_elems[k] = C[offset_c + k * fdc];
 
  571        for (
short k = 0; k < kelems; k++) {
 
  572          accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
 
 
  585      thread 
const Epilogue& epilogue_op)
 const {
 
  587    C += (
sm)*ldc + (
sn)*fdc;
 
  590    constexpr short kelems = 
decltype(
Ctile)::kElemsPerFrag;
 
  594    for (
short i = 0; i < 
TM; i++) {
 
  596      for (
short j = 0; j < 
TN; j++) {
 
  598        thread 
const auto& accum = 
Ctile.frag_at(i, j);
 
  604        for (
short k = 0; k < kelems; k++) {
 
  605          D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
 
 
  617      short2 dst_tile_dims,
 
  618      thread 
const Epilogue& epilogue_op)
 const {
 
  620    C += (
sm)*ldc + (
sn)*fdc;
 
  622    dst_tile_dims -= short2(
sn, 
sm);
 
  624    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
 
  627    constexpr short kelems = 
decltype(
Ctile)::kElemsPerFrag;
 
  630    for (
int i = 0; i < 
TM; i++) {
 
  633        for (
int j = 0; j < 
TN; j++) {
 
  635          thread 
const auto& accum = 
Ctile.frag_at(i, j);
 
  641          for (
short k = 0; k < kelems; k++) {
 
  642            if ((j * 
TN_stride + k) < dst_tile_dims.x) {
 
  644                  epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
 
 
 
METAL_FUNC void tile_matmad(thread MMATile< T, M, N > &D, thread MMATile< U, M, K > &A, thread MMATile< U, K, N > &B, thread MMATile< T, M, N > &C)
Definition mma.h:341
 
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
 
#define STEEL_CONST
Definition defines.h:3
 
static METAL_FUNC constexpr void mma(thread mat_type &D, thread mat_type &A, thread mat_type &B, thread mat_type &C)
Definition mma.h:164
 
static METAL_FUNC constexpr void store_safe(const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{})
Definition mma.h:122
 
metal::simdgroup_matrix< T, kFragRows, kFragCols > mat_type
Definition mma.h:46
 
static METAL_FUNC constexpr short2 get_coord(ushort simd_lane_id)
Definition mma.h:49
 
static METAL_FUNC constexpr void mma(thread frag_type &D, thread frag_type &A, thread frag_type &B, thread frag_type &C)
Definition mma.h:145
 
static METAL_FUNC constexpr void store(const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y)
Definition mma.h:102
 
static METAL_FUNC constexpr void load(thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y)
Definition mma.h:59
 
static METAL_FUNC constexpr void load_safe(thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{})
Definition mma.h:77
 
metal::vec< T, kElemsPerFrag > frag_type
Definition mma.h:47
 
METAL_FUNC void store_result(device U *D, const int ldd)
Definition mma.h:464
 
METAL_FUNC void store_result_safe(device U *D, const int ldd, short2 dst_tile_dims)
Definition mma.h:478
 
short As_offset
Definition mma.h:413
 
MMATile< AccumType, 1, TN, MMAFrag_acc_t > Btile
Definition mma.h:406
 
STEEL_CONST short A_str_k
Definition mma.h:394
 
STEEL_CONST short B_str_n
Definition mma.h:398
 
STEEL_CONST short TM_stride
Definition mma.h:383
 
METAL_FUNC void mma(const threadgroup T *As, const threadgroup T *Bs)
Definition mma.h:437
 
STEEL_CONST short TN
Definition mma.h:390
 
METAL_FUNC void store_result_safe(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const Epilogue &epilogue_op) const
Definition mma.h:611
 
METAL_FUNC void store_result(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, thread const Epilogue &epilogue_op) const
Definition mma.h:579
 
MMATile< AccumType, TM, TN, MMAFrag_acc_t > Ctile
Definition mma.h:407
 
METAL_FUNC void apply_epilogue(const device U *C, const int ldc, const int fdc, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:507
 
STEEL_CONST short TN_stride
Definition mma.h:385
 
STEEL_CONST short tile_stride_a
Definition mma.h:401
 
short Bs_offset
Definition mma.h:414
 
METAL_FUNC void apply_epilogue_safe(const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:535
 
METAL_FUNC BlockMMA(ushort simd_group_id, ushort simd_lane_id)
Definition mma.h:417
 
STEEL_CONST short B_str_k
Definition mma.h:397
 
short sm
Definition mma.h:410
 
STEEL_CONST short A_str_m
Definition mma.h:393
 
STEEL_CONST short TM
Definition mma.h:388
 
short sn
Definition mma.h:411
 
STEEL_CONST short tile_stride_b
Definition mma.h:402
 
STEEL_CONST short kFragSize
Definition mma.h:379
 
MMATile< AccumType, TM, 1, MMAFrag_acc_t > Atile
Definition mma.h:405
 
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue &epilogue_op)
Definition mma.h:497
 
METAL_FUNC constexpr thread frag_type & frag_at(const short i, const short j)
Definition mma.h:208
 
STEEL_CONST int kTileRows
Definition mma.h:185
 
MMAFrag_t::mat_type mat_type
Definition mma.h:194
 
METAL_FUNC void store(threadgroup U *dst) const
Definition mma.h:253
 
METAL_FUNC mat_type mat_at(const short i, const short j)
Definition mma.h:218
 
STEEL_CONST int kTileCols
Definition mma.h:186
 
METAL_FUNC void store_safe(device U *dst, const int ld, const short2 dst_tile_dims) const
Definition mma.h:321
 
STEEL_CONST int kFragRows
Definition mma.h:181
 
MMAFrag_t::frag_type frag_type
Definition mma.h:195
 
STEEL_CONST int kRows
Definition mma.h:188
 
METAL_FUNC void store(device U *dst, const int ld) const
Definition mma.h:285
 
T elem_type
Definition mma.h:180
 
METAL_FUNC thread elem_type * elems()
Definition mma.h:227
 
STEEL_CONST int kCols
Definition mma.h:189
 
STEEL_CONST int kElemsPerTile
Definition mma.h:192
 
METAL_FUNC void load_safe(const device U *src, const int ld, const short2 src_tile_dims)
Definition mma.h:301
 
METAL_FUNC MMATile() thread
Definition mma.h:199
 
METAL_FUNC void load(const threadgroup U *src)
Definition mma.h:236
 
METAL_FUNC constexpr void clear()
Definition mma.h:201
 
METAL_FUNC void load(const device U *src, const int ld)
Definition mma.h:270
 
MMAFrag_ MMAFrag_t
Definition mma.h:179
 
frag_type val_frags[kNumFrags]
Definition mma.h:197
 
STEEL_CONST int kFragCols
Definition mma.h:182
 
METAL_FUNC constexpr const thread frag_type & frag_at(const short i, const short j) const
Definition mma.h:212
 
METAL_FUNC const thread elem_type * elems() const
Definition mma.h:231
 
STEEL_CONST int kNumFrags
Definition mma.h:191
 
STEEL_CONST int kElemsPerFrag
Definition mma.h:183
 
Definition integral_constant.h:18