5#include <metal_simdgroup>
6#include <metal_simdgroup_matrix>
33 typename AccumType = float,
34 typename Epilogue = TransformNone<U, AccumType>>
60 simdgroup_matrix<AccumType, 8, 8>
Asimd[
TM];
61 simdgroup_matrix<AccumType, 8, 8>
Bsimd[
TN];
63 simdgroup_matrix<AccumType, 8, 8>(0)};
77 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
78 ushort simd_lane_id [[thread_index_in_simdgroup]])
79 :
tm(8 * (simd_group_id / WN)),
tn(8 * (simd_group_id % WN)) {
81 short qid = simd_lane_id / 4;
82 sm = (qid & 4) + (simd_lane_id / 2) % 4;
83 sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
87 transpose_a ? ((
sn)*lda_tgp + (
tm +
sm)) : ((
sn) + (
tm +
sm) * lda_tgp);
89 transpose_b ? ((
tn +
sn) * ldb_tgp + (
sm)) : ((
sm)*ldb_tgp + (
tn +
sn));
93 METAL_FUNC
void mma(
const threadgroup T* As,
const threadgroup T* Bs) {
100 for (
short kk = 0; kk < BK; kk += 8) {
101 simdgroup_barrier(mem_flags::mem_none);
105 for (
short i = 0; i <
TM; i++) {
106 Asimd[i].thread_elements()[0] =
108 Asimd[i].thread_elements()[1] =
112 simdgroup_barrier(mem_flags::mem_none);
116 for (
short j = 0; j <
TN; j++) {
117 Bsimd[j].thread_elements()[0] =
119 Bsimd[j].thread_elements()[1] =
123 simdgroup_barrier(mem_flags::mem_none);
127 for (
short i = 0; i <
TM; i++) {
129 for (
short j = 0; j <
TN; j++) {
130 short j_serp = (i % 2) ? (
TN - 1 - j) : j;
132 simdgroup_multiply_accumulate(
153 for (
short i = 0; i <
TM; i++) {
155 for (
short j = 0; j <
TN; j++) {
157 thread
const auto& accum =
results[i *
TN + j].thread_elements();
161 U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
165 D[offset + 1] = outs[1];
174 dst_tile_dims -= short2(
tn +
sn,
sm +
tm);
176 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
180 for (
int i = 0; i <
TM; i++) {
183 for (
int j = 0; j <
TN; j++) {
185 thread
const auto& accum =
results[i *
TN + j].thread_elements();
190 D[offset] = Epilogue::apply(accum[0]);
193 if (j *
TN_stride + 1 < dst_tile_dims.x) {
194 D[offset + 1] = Epilogue::apply(accum[1]);
202 template <
typename UnaryEpilogue>
206 for (
short i = 0; i <
TM; i++) {
208 for (
short j = 0; j <
TN; j++) {
210 thread
auto& accum =
results[i *
TN + j].thread_elements();
213 accum[0] = epilogue_op.apply(accum[0]);
214 accum[1] = epilogue_op.apply(accum[1]);
220 template <
typename BinaryEpilogue>
225 thread
const BinaryEpilogue& epilogue_op) {
227 C += (
sm +
tm) * ldc + (
tn +
sn) * fdc;
231 for (
short i = 0; i <
TM; i++) {
233 for (
short j = 0; j <
TN; j++) {
235 thread
auto& accum =
results[i *
TN + j].thread_elements();
239 accum[0] = epilogue_op.apply(accum[0], C[offset_c]);
240 accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
246 template <
typename BinaryEpilogue>
251 short2 dst_tile_dims,
252 thread
const BinaryEpilogue& epilogue_op) {
254 C += (
sm +
tm) * ldc + (
tn +
sn) * fdc;
255 dst_tile_dims -= short2(
tn +
sn,
sm +
tm);
257 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
262 for (
short i = 0; i <
TM; i++) {
264 for (
short j = 0; j <
TN; j++) {
266 thread
auto& accum =
results[i *
TN + j].thread_elements();
272 if ((j *
TN_stride + 1) < dst_tile_dims.x) {
273 c_elems[0] = C[offset_c];
274 c_elems[1] = C[offset_c + fdc];
275 }
else if ((j *
TN_stride) < dst_tile_dims.x) {
276 c_elems[0] = C[offset_c];
280 accum[0] = epilogue_op.apply(accum[0], c_elems[0]);
281 accum[1] = epilogue_op.apply(accum[1], c_elems[1]);
293 thread
const Epilogue& epilogue_op)
const {
295 C += (
sm +
tm) * ldc + (
tn +
sn) * fdc;
300 for (
short i = 0; i <
TM; i++) {
302 for (
short j = 0; j <
TN; j++) {
304 thread
const auto& accum =
results[i *
TN + j].thread_elements();
310 epilogue_op.apply(accum[0], C[offset_c]),
311 epilogue_op.apply(accum[1], C[offset_c + fdc])};
314 D[offset_d] = outs[0];
315 D[offset_d + 1] = outs[1];
326 short2 dst_tile_dims,
327 thread
const Epilogue& epilogue_op)
const {
329 C += (
sm +
tm) * ldc + (
tn +
sn) * fdc;
331 dst_tile_dims -= short2(
tn +
sn,
sm +
tm);
333 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
337 for (
int i = 0; i <
TM; i++) {
340 for (
int j = 0; j <
TN; j++) {
342 thread
const auto& accum =
results[i *
TN + j].thread_elements();
348 D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
351 if (j *
TN_stride + 1 < dst_tile_dims.x) {
352 D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
short As_offset
Definition mma.h:72
STEEL_CONST short jump_b
Definition mma.h:54
simdgroup_matrix< AccumType, 8, 8 > Bsimd[TN]
Definition mma.h:61
STEEL_CONST short TM_stride
Definition mma.h:37
METAL_FUNC void mma(const threadgroup T *As, const threadgroup T *Bs)
Definition mma.h:93
simdgroup_matrix< AccumType, 8, 8 > results[TM *TN]
Definition mma.h:62
STEEL_CONST short TN
Definition mma.h:44
METAL_FUNC void store_result_safe(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const Epilogue &epilogue_op) const
Definition mma.h:320
METAL_FUNC void store_result(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, thread const Epilogue &epilogue_op) const
Definition mma.h:287
METAL_FUNC void apply_epilogue(const device U *C, const int ldc, const int fdc, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:221
METAL_FUNC void store_result(device U *D, const int ldd) const
Definition mma.h:147
STEEL_CONST short TN_stride
Definition mma.h:39
STEEL_CONST short tile_stride_a
Definition mma.h:56
simdgroup_matrix< AccumType, 8, 8 > Asimd[TM]
Definition mma.h:60
short Bs_offset
Definition mma.h:73
METAL_FUNC void apply_epilogue_safe(const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:247
METAL_FUNC BlockMMA(ushort simd_group_id, ushort simd_lane_id)
Definition mma.h:76
short sm
Definition mma.h:69
STEEL_CONST short simd_stride_a
Definition mma.h:47
const short tm
Definition mma.h:66
STEEL_CONST short TM
Definition mma.h:42
const short tn
Definition mma.h:67
STEEL_CONST short jump_a
Definition mma.h:53
short sn
Definition mma.h:70
STEEL_CONST short tile_stride_b
Definition mma.h:57
STEEL_CONST short simd_stride_b
Definition mma.h:49
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue &epilogue_op)
Definition mma.h:203
METAL_FUNC void store_result_safe(device U *D, const int ldd, short2 dst_tile_dims) const
Definition mma.h:171