5#include <metal_simdgroup> 
    6#include <metal_simdgroup_matrix> 
   33    typename AccumType = float,
 
   34    typename Epilogue = TransformNone<U, AccumType>>
 
   60  simdgroup_matrix<AccumType, 8, 8> 
Asimd[
TM];
 
   61  simdgroup_matrix<AccumType, 8, 8> 
Bsimd[
TN];
 
   63      simdgroup_matrix<AccumType, 8, 8>(0)};
 
 
   77      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
 
   78      ushort simd_lane_id [[thread_index_in_simdgroup]])
 
   79      : 
tm(8 * (simd_group_id / WN)), 
tn(8 * (simd_group_id % WN)) {
 
   81    short qid = simd_lane_id / 4;
 
   82    sm = (qid & 4) + (simd_lane_id / 2) % 4;
 
   83    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
 
   87        transpose_a ? ((
sn)*lda_tgp + (
tm + 
sm)) : ((
sn) + (
tm + 
sm) * lda_tgp);
 
   89        transpose_b ? ((
tn + 
sn) * ldb_tgp + (
sm)) : ((
sm)*ldb_tgp + (
tn + 
sn));
 
 
   93  METAL_FUNC 
void mma(
const threadgroup T* As, 
const threadgroup T* Bs) {
 
  100    for (
short kk = 0; kk < BK; kk += 8) {
 
  101      simdgroup_barrier(mem_flags::mem_none);
 
  105      for (
short i = 0; i < 
TM; i++) {
 
  106        Asimd[i].thread_elements()[0] =
 
  108        Asimd[i].thread_elements()[1] =
 
  112      simdgroup_barrier(mem_flags::mem_none);
 
  116      for (
short j = 0; j < 
TN; j++) {
 
  117        Bsimd[j].thread_elements()[0] =
 
  119        Bsimd[j].thread_elements()[1] =
 
  123      simdgroup_barrier(mem_flags::mem_none);
 
  127      for (
short i = 0; i < 
TM; i++) {
 
  129        for (
short j = 0; j < 
TN; j++) {
 
  130          short j_serp = (i % 2) ? (
TN - 1 - j) : j;
 
  132          simdgroup_multiply_accumulate(
 
 
  153    for (
short i = 0; i < 
TM; i++) {
 
  155      for (
short j = 0; j < 
TN; j++) {
 
  157        thread 
const auto& accum = 
results[i * 
TN + j].thread_elements();
 
  161        U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
 
  165        D[offset + 1] = outs[1];
 
 
  174    dst_tile_dims -= short2(
tn + 
sn, 
sm + 
tm);
 
  176    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
 
  180    for (
int i = 0; i < 
TM; i++) {
 
  183        for (
int j = 0; j < 
TN; j++) {
 
  185          thread 
const auto& accum = 
results[i * 
TN + j].thread_elements();
 
  190            D[offset] = Epilogue::apply(accum[0]);
 
  193          if (j * 
TN_stride + 1 < dst_tile_dims.x) {
 
  194            D[offset + 1] = Epilogue::apply(accum[1]);
 
 
  208      thread 
const Epilogue& epilogue_op)
 const {
 
  210    C += (
sm + 
tm) * ldc + (
tn + 
sn) * fdc;
 
  215    for (
short i = 0; i < 
TM; i++) {
 
  217      for (
short j = 0; j < 
TN; j++) {
 
  219        thread 
const auto& accum = 
results[i * 
TN + j].thread_elements();
 
  225            epilogue_op.apply(accum[0], C[offset_c]),
 
  226            epilogue_op.apply(accum[1], C[offset_c + fdc])};
 
  229        D[offset_d] = outs[0];
 
  230        D[offset_d + 1] = outs[1];
 
 
  241      short2 dst_tile_dims,
 
  242      thread 
const Epilogue& epilogue_op)
 const {
 
  244    C += (
sm + 
tm) * ldc + (
tn + 
sn) * fdc;
 
  246    dst_tile_dims -= short2(
tn + 
sn, 
sm + 
tm);
 
  248    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
 
  252    for (
int i = 0; i < 
TM; i++) {
 
  255        for (
int j = 0; j < 
TN; j++) {
 
  257          thread 
const auto& accum = 
results[i * 
TN + j].thread_elements();
 
  263            D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
 
  266          if (j * 
TN_stride + 1 < dst_tile_dims.x) {
 
  267            D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
 
 
 
short As_offset
Definition mma.h:72
 
STEEL_CONST short jump_b
Definition mma.h:54
 
simdgroup_matrix< AccumType, 8, 8 > Bsimd[TN]
Definition mma.h:61
 
STEEL_CONST short TM_stride
Definition mma.h:37
 
METAL_FUNC void mma(const threadgroup T *As, const threadgroup T *Bs)
Definition mma.h:93
 
simdgroup_matrix< AccumType, 8, 8 > results[TM *TN]
Definition mma.h:62
 
STEEL_CONST short TN
Definition mma.h:44
 
METAL_FUNC void store_result_safe(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const Epilogue &epilogue_op) const
Definition mma.h:235
 
METAL_FUNC void store_result(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, thread const Epilogue &epilogue_op) const
Definition mma.h:202
 
METAL_FUNC void store_result(device U *D, const int ldd) const
Definition mma.h:147
 
STEEL_CONST short TN_stride
Definition mma.h:39
 
STEEL_CONST short tile_stride_a
Definition mma.h:56
 
simdgroup_matrix< AccumType, 8, 8 > Asimd[TM]
Definition mma.h:60
 
short Bs_offset
Definition mma.h:73
 
METAL_FUNC BlockMMA(ushort simd_group_id, ushort simd_lane_id)
Definition mma.h:76
 
short sm
Definition mma.h:69
 
STEEL_CONST short simd_stride_a
Definition mma.h:47
 
const short tm
Definition mma.h:66
 
STEEL_CONST short TM
Definition mma.h:42
 
const short tn
Definition mma.h:67
 
STEEL_CONST short jump_a
Definition mma.h:53
 
short sn
Definition mma.h:70
 
STEEL_CONST short tile_stride_b
Definition mma.h:57
 
STEEL_CONST short simd_stride_b
Definition mma.h:49
 
METAL_FUNC void store_result_safe(device U *D, const int ldd, short2 dst_tile_dims) const
Definition mma.h:171