5#include <metal_simdgroup>
6#include <metal_simdgroup_matrix>
33 typename AccumType = float,
34 typename Epilogue = TransformNone<U, AccumType>>
60 simdgroup_matrix<AccumType, 8, 8>
Asimd[
TM];
61 simdgroup_matrix<AccumType, 8, 8>
Bsimd[
TN];
63 simdgroup_matrix<AccumType, 8, 8>(0)};
77 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
78 ushort simd_lane_id [[thread_index_in_simdgroup]])
79 :
tm(8 * (simd_group_id / WN)),
tn(8 * (simd_group_id % WN)) {
81 short qid = simd_lane_id / 4;
82 sm = (qid & 4) + (simd_lane_id / 2) % 4;
83 sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
87 transpose_a ? ((
sn)*lda_tgp + (
tm +
sm)) : ((
sn) + (
tm +
sm) * lda_tgp);
89 transpose_b ? ((
tn +
sn) * ldb_tgp + (
sm)) : ((
sm)*ldb_tgp + (
tn +
sn));
93 METAL_FUNC
void mma(
const threadgroup T* As,
const threadgroup T* Bs) {
100 for (
short kk = 0; kk < BK; kk += 8) {
101 simdgroup_barrier(mem_flags::mem_none);
105 for (
short i = 0; i <
TM; i++) {
106 Asimd[i].thread_elements()[0] =
108 Asimd[i].thread_elements()[1] =
112 simdgroup_barrier(mem_flags::mem_none);
116 for (
short j = 0; j <
TN; j++) {
117 Bsimd[j].thread_elements()[0] =
119 Bsimd[j].thread_elements()[1] =
123 simdgroup_barrier(mem_flags::mem_none);
127 for (
short i = 0; i <
TM; i++) {
129 for (
short j = 0; j <
TN; j++) {
130 short j_serp = (i % 2) ? (
TN - 1 - j) : j;
132 simdgroup_multiply_accumulate(
153 for (
short i = 0; i <
TM; i++) {
155 for (
short j = 0; j <
TN; j++) {
157 thread
const auto& accum =
results[i *
TN + j].thread_elements();
161 U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
165 D[offset + 1] = outs[1];
174 dst_tile_dims -= short2(
tn +
sn,
sm +
tm);
176 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
180 for (
int i = 0; i <
TM; i++) {
183 for (
int j = 0; j <
TN; j++) {
185 thread
const auto& accum =
results[i *
TN + j].thread_elements();
190 D[offset] = Epilogue::apply(accum[0]);
193 if (j *
TN_stride + 1 < dst_tile_dims.x) {
194 D[offset + 1] = Epilogue::apply(accum[1]);
208 thread
const Epilogue& epilogue_op)
const {
210 C += (
sm +
tm) * ldc + (
tn +
sn) * fdc;
215 for (
short i = 0; i <
TM; i++) {
217 for (
short j = 0; j <
TN; j++) {
219 thread
const auto& accum =
results[i *
TN + j].thread_elements();
225 epilogue_op.apply(accum[0], C[offset_c]),
226 epilogue_op.apply(accum[1], C[offset_c + fdc])};
229 D[offset_d] = outs[0];
230 D[offset_d + 1] = outs[1];
241 short2 dst_tile_dims,
242 thread
const Epilogue& epilogue_op)
const {
244 C += (
sm +
tm) * ldc + (
tn +
sn) * fdc;
246 dst_tile_dims -= short2(
tn +
sn,
sm +
tm);
248 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
252 for (
int i = 0; i <
TM; i++) {
255 for (
int j = 0; j <
TN; j++) {
257 thread
const auto& accum =
results[i *
TN + j].thread_elements();
263 D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
266 if (j *
TN_stride + 1 < dst_tile_dims.x) {
267 D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
short As_offset
Definition mma.h:72
STEEL_CONST short jump_b
Definition mma.h:54
simdgroup_matrix< AccumType, 8, 8 > Bsimd[TN]
Definition mma.h:61
STEEL_CONST short TM_stride
Definition mma.h:37
METAL_FUNC void mma(const threadgroup T *As, const threadgroup T *Bs)
Definition mma.h:93
simdgroup_matrix< AccumType, 8, 8 > results[TM *TN]
Definition mma.h:62
STEEL_CONST short TN
Definition mma.h:44
METAL_FUNC void store_result_safe(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const Epilogue &epilogue_op) const
Definition mma.h:235
METAL_FUNC void store_result(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, thread const Epilogue &epilogue_op) const
Definition mma.h:202
METAL_FUNC void store_result(device U *D, const int ldd) const
Definition mma.h:147
STEEL_CONST short TN_stride
Definition mma.h:39
STEEL_CONST short tile_stride_a
Definition mma.h:56
simdgroup_matrix< AccumType, 8, 8 > Asimd[TM]
Definition mma.h:60
short Bs_offset
Definition mma.h:73
METAL_FUNC BlockMMA(ushort simd_group_id, ushort simd_lane_id)
Definition mma.h:76
short sm
Definition mma.h:69
STEEL_CONST short simd_stride_a
Definition mma.h:47
const short tm
Definition mma.h:66
STEEL_CONST short TM
Definition mma.h:42
const short tn
Definition mma.h:67
STEEL_CONST short jump_a
Definition mma.h:53
short sn
Definition mma.h:70
STEEL_CONST short tile_stride_b
Definition mma.h:57
STEEL_CONST short simd_stride_b
Definition mma.h:49
METAL_FUNC void store_result_safe(device U *D, const int ldd, short2 dst_tile_dims) const
Definition mma.h:171