5#include <metal_simdgroup>
6#include <metal_simdgroup_matrix>
22template <
typename RInt,
typename CInt>
30template <
typename Shape,
typename Layout>
36template <
typename T,
int kFragRows_,
int kFragCols_>
40 "Only 8 x 8 fragment matrices are currently supported");
43 "Only 8 x 8 fragment matrices are currently supported");
58 "MMAFrag shape is not consistent with MMAFrag size");
60 typedef metal::simdgroup_matrix<T, kFragRows, kFragCols>
mat_type;
66 using dtype_mat_t =
typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;
71 METAL_FUNC
static constexpr short2
get_coord(ushort simd_lane_id
72 [[thread_index_in_simdgroup]]) {
73 const short qid = simd_lane_id / 4;
74 const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
75 const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
76 return short2{fn, fm};
79 template <
typename SrcPtrType,
typename StrX,
typename StrY>
80 METAL_FUNC
static constexpr void
86 dst[i *
kElemCols + j] =
static_cast<T
>(src[i * str_x + j * str_y]);
109 for (
short i = 0; i < kElemRows; i++) {
111 for (
short j = 0; j < kElemCols; j++) {
112 if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
// Row index uses off_x, column index uses off_y: the same pair the bounds
// check above tests and the same pair store_safe uses for its destination.
113 dst[i * kElemCols + j] =
114 static_cast<T
>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
116 dst[i * kElemCols + j] = T(0);
122 template <
typename DstPtrType,
typename StrX,
typename StrY>
123 METAL_FUNC
static constexpr void
131 dst[i * str_x + j * str_y] =
static_cast<U
>(src[i *
kElemCols + j]);
156 for (
short i = 0; i < kElemRows; i++) {
158 for (
short j = 0; j < kElemCols; j++) {
159 if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
160 dst[(off_x + i) * str_x + (off_y + j) * str_y] =
161 static_cast<U
>(src[i * kElemCols + j]);
167 template <
typename Atype,
typename Btype,
typename Ctype>
168 METAL_FUNC
static constexpr void mma(
182 mma(D_mat, A_mat, B_mat, C_mat);
184 D =
reinterpret_cast<thread
frag_type&
>(D_mat.thread_elements());
187 template <
typename Atype,
typename Btype,
typename Ctype>
188 METAL_FUNC
static constexpr void mma(
193 simdgroup_multiply_accumulate(D, A, B, C);
196 template <
typename Op>
199 thread T* reduced_vals) {
200 T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);
203 qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
206 sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
208 reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
211 template <
typename Op>
214 thread T* row_vals) {
220 Op::apply(inp_vals[i *
kElemCols + j], row_vals[i]);
230 class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
270 const short j)
const {
278 val_mat.thread_elements()[ii] =
frag_at(i, j)[ii];
291 template <
typename Op>
298 frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
303 template <
typename Op>
310 frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
315 template <
typename U,
int w_x,
int w_y,
int str_x,
int str_y>
316 METAL_FUNC
void load(
const threadgroup U* src) {
332 template <
typename U,
int w_x,
int w_y,
int str_x,
int str_y>
333 METAL_FUNC
void store(threadgroup U* dst)
const {
349 template <
typename U,
int w_x,
int w_y>
350 METAL_FUNC
void load(
const device U* src,
const int ld) {
364 template <
typename U,
int w_x,
int w_y>
365 METAL_FUNC
void store(device U* dst,
const int ld)
const {
379 template <
typename U,
int w_x,
int w_y>
381 load_safe(
const device U* src,
const int ld,
const short2 src_tile_dims) {
386 MMAFrag_t::load_safe(
399 template <
typename U,
int w_x,
int w_y>
401 store_safe(device U* dst,
const int ld,
const short2 dst_tile_dims)
const {
406 MMAFrag_t::store_safe(
438 for (
short m = 0; m < M; ++m) {
440 for (
short n = 0; n < N; ++n) {
442 short n_serp = (m % 2) ? (N - 1 - n) : n;
445 for (
short k = 0; k < K; ++k) {
447 D.frag_at(m_serp, n_serp),
448 A.frag_at(m_serp, k),
449 B.frag_at(k, n_serp),
450 C.frag_at(m_serp, n_serp));
468 typename AccumType = float,
469 typename Epilogue = TransformNone<U, AccumType>>
511 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
512 ushort simd_lane_id [[thread_index_in_simdgroup]]) {
514 short tm =
kFragSize * (simd_group_id / WN);
515 short tn =
kFragSize * (simd_group_id % WN);
517 short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
530 METAL_FUNC
void mma(
const threadgroup T* As,
const threadgroup T* Bs) {
537 for (
short kk = 0; kk < BK; kk +=
kFragSize) {
538 simdgroup_barrier(mem_flags::mem_none);
540 Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
542 simdgroup_barrier(mem_flags::mem_none);
544 Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
546 simdgroup_barrier(mem_flags::mem_none);
560 for (
short i = 0; i <
decltype(
Ctile)::kElemsPerTile; i++) {
561 Ctile.elems()[i] = Epilogue::apply(
Ctile.elems()[i]);
567 Ctile.template store<U, WM, WN>(D, ldd);
574 for (
short i = 0; i <
decltype(
Ctile)::kElemsPerTile; i++) {
575 Ctile.elems()[i] = Epilogue::apply(
Ctile.elems()[i]);
580 dst_tile_dims -= short2(
sn,
sm);
582 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
585 Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
589 template <
typename UnaryEpilogue>
593 for (
short i = 0; i <
decltype(
Ctile)::kElemsPerTile; i++) {
594 Ctile.elems()[i] = epilogue_op.apply(
Ctile.elems()[i]);
599 template <
typename BinaryEpilogue>
604 thread
const BinaryEpilogue& epilogue_op) {
606 C += (
sm)*ldc + (
sn)*fdc;
610 for (
short i = 0; i <
TM; i++) {
612 for (
short j = 0; j <
TN; j++) {
614 thread
auto& accum =
Ctile.frag_at(i, j);
619 for (
short k = 0; k <
decltype(
Ctile)::kElemsPerFrag; k++) {
620 accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
627 template <
typename BinaryEpilogue>
632 short2 dst_tile_dims,
633 thread
const BinaryEpilogue& epilogue_op) {
635 C += (
sm)*ldc + (
sn)*fdc;
636 dst_tile_dims -= short2(
sn,
sm);
638 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
643 for (
short i = 0; i <
TM; i++) {
645 for (
short j = 0; j <
TN; j++) {
647 thread
auto& accum =
Ctile.frag_at(i, j);
650 constexpr short kelems =
decltype(
Ctile)::kElemsPerFrag;
653 U c_elems[kelems] = {0};
656 for (
short k = 0; k < kelems; k++) {
657 if ((j *
TN_stride + k) < dst_tile_dims.x) {
658 c_elems[k] = C[offset_c + k * fdc];
664 for (
short k = 0; k < kelems; k++) {
665 accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
678 thread
const Epilogue& epilogue_op)
const {
680 C += (
sm)*ldc + (
sn)*fdc;
683 constexpr short kelems =
decltype(
Ctile)::kElemsPerFrag;
687 for (
short i = 0; i <
TM; i++) {
689 for (
short j = 0; j <
TN; j++) {
691 thread
const auto& accum =
Ctile.frag_at(i, j);
697 for (
short k = 0; k < kelems; k++) {
698 D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
710 short2 dst_tile_dims,
711 thread
const Epilogue& epilogue_op)
const {
713 C += (
sm)*ldc + (
sn)*fdc;
715 dst_tile_dims -= short2(
sn,
sm);
717 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
720 constexpr short kelems =
decltype(
Ctile)::kElemsPerFrag;
723 for (
int i = 0; i <
TM; i++) {
726 for (
int j = 0; j <
TN; j++) {
728 thread
const auto& accum =
Ctile.frag_at(i, j);
734 for (
short k = 0; k < kelems; k++) {
735 if ((j *
TN_stride + k) < dst_tile_dims.x) {
737 epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
METAL_FUNC void tile_matmad(thread MMATile< Dtype, M, N, MMAFragD > &D, thread MMATile< Atype, M, K, MMAFragA > &A, thread MMATile< Btype, K, N, MMAFragB > &B, thread MMATile< Ctype, M, N, MMAFragC > &C)
Definition mma.h:432
integral_constant< int, val > Int
Definition integral_constant.h:48
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
static METAL_FUNC constexpr void store_safe(const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{})
Definition mma.h:144
STEEL_CONST int kFragCols
Definition mma.h:49
metal::simdgroup_matrix< T, kFragRows, kFragCols > mat_type
Definition mma.h:60
STEEL_CONST int kFragRows
Definition mma.h:48
static METAL_FUNC constexpr void row_bin_op(thread frag_type &inp_vals, thread T *row_vals)
Definition mma.h:212
STEEL_CONST int kElemsPerFrag
Definition mma.h:51
metal::vec< T, kElemRows > row_frag_type
Definition mma.h:62
static METAL_FUNC constexpr void row_reduce(thread const frag_type &inp_vals, thread T *reduced_vals)
Definition mma.h:197
typename metal::vec< U, kElemsPerFrag > dtype_frag_t
Definition mma.h:69
static METAL_FUNC constexpr short2 get_coord(ushort simd_lane_id)
Definition mma.h:71
STEEL_CONST int kElemRows
Definition mma.h:53
STEEL_CONST int kElemCols
Definition mma.h:54
metal::vec< T, kElemsPerFrag > frag_type
Definition mma.h:61
typename metal::simdgroup_matrix< U, kFragRows, kFragCols > dtype_mat_t
Definition mma.h:66
static METAL_FUNC constexpr void store(const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y)
Definition mma.h:124
static METAL_FUNC constexpr void load(thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y)
Definition mma.h:81
static METAL_FUNC constexpr void load_safe(thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y, LimX lim_x, LimY lim_y, OffX off_x=Int< 0 >{}, OffY off_y=Int< 0 >{})
Definition mma.h:99
metal::vec< T, kElemCols > col_frag_type
Definition mma.h:63
static METAL_FUNC constexpr void mma(thread frag_type &D, thread dtype_frag_t< Atype > &A, thread dtype_frag_t< Btype > &B, thread dtype_frag_t< Ctype > &C)
Definition mma.h:168
static METAL_FUNC constexpr void mma(thread mat_type &D, thread dtype_mat_t< Atype > &A, thread dtype_mat_t< Btype > &B, thread dtype_mat_t< Ctype > &C)
Definition mma.h:188
METAL_FUNC void store_result(device U *D, const int ldd)
Definition mma.h:557
METAL_FUNC void store_result_safe(device U *D, const int ldd, short2 dst_tile_dims)
Definition mma.h:571
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::As_offset short As_offset
Definition mma.h:506
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::Ctile MMATile< AccumType, TM, TN, MMAFrag_acc_t > Ctile
Definition mma.h:500
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::A_str_k STEEL_CONST short A_str_k
Definition mma.h:487
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::Btile MMATile< AccumType, 1, TN, MMAFrag_acc_t > Btile
Definition mma.h:499
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::Atile MMATile< AccumType, TM, 1, MMAFrag_acc_t > Atile
Definition mma.h:498
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::B_str_n STEEL_CONST short B_str_n
Definition mma.h:491
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::TM_stride STEEL_CONST short TM_stride
Definition mma.h:476
METAL_FUNC void mma(const threadgroup T *As, const threadgroup T *Bs)
Definition mma.h:530
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::TN STEEL_CONST short TN
Definition mma.h:483
METAL_FUNC void store_result_safe(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const Epilogue &epilogue_op) const
Definition mma.h:704
METAL_FUNC void store_result(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, thread const Epilogue &epilogue_op) const
Definition mma.h:672
METAL_FUNC void apply_epilogue(const device U *C, const int ldc, const int fdc, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:600
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::TN_stride STEEL_CONST short TN_stride
Definition mma.h:478
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::tile_stride_a STEEL_CONST short tile_stride_a
Definition mma.h:494
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::Bs_offset short Bs_offset
Definition mma.h:507
METAL_FUNC void apply_epilogue_safe(const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:628
METAL_FUNC BlockMMA(ushort simd_group_id, ushort simd_lane_id)
Definition mma.h:510
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::B_str_k STEEL_CONST short B_str_k
Definition mma.h:490
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::sm short sm
Definition mma.h:503
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::A_str_m STEEL_CONST short A_str_m
Definition mma.h:486
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::TM STEEL_CONST short TM
Definition mma.h:481
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::sn short sn
Definition mma.h:504
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::MMAFrag_acc_t BaseMMAFrag< AccumType, kFragSize, kFragSize > MMAFrag_acc_t
Definition mma.h:473
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::tile_stride_b STEEL_CONST short tile_stride_b
Definition mma.h:495
mlx::steel::BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue >::kFragSize STEEL_CONST short kFragSize
Definition mma.h:472
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue &epilogue_op)
Definition mma.h:590
Shape shape
Definition mma.h:32
Layout layout
Definition mma.h:33
METAL_FUNC constexpr thread frag_type & frag_at(const short i, const short j)
Definition mma.h:264
STEEL_CONST int kTileRows
Definition mma.h:238
STEEL_CONST int kColsPerThread
Definition mma.h:248
MMAFrag_t::mat_type mat_type
Definition mma.h:250
METAL_FUNC void store(threadgroup U *dst) const
Definition mma.h:333
METAL_FUNC mat_type mat_at(const short i, const short j)
Definition mma.h:274
METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread])
Definition mma.h:304
STEEL_CONST int kTileCols
Definition mma.h:239
METAL_FUNC void store_safe(device U *dst, const int ld, const short2 dst_tile_dims) const
Definition mma.h:401
STEEL_CONST int kFragRows
Definition mma.h:234
STEEL_CONST int kRowsPerThread
Definition mma.h:247
STEEL_CONST int kRows
Definition mma.h:241
frag_type val_frags[kNumFrags]
Definition mma.h:253
METAL_FUNC void store(device U *dst, const int ld) const
Definition mma.h:365
T elem_type
Definition mma.h:233
METAL_FUNC thread elem_type * elems()
Definition mma.h:283
STEEL_CONST int kCols
Definition mma.h:242
STEEL_CONST int kElemsPerTile
Definition mma.h:245
METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const
Definition mma.h:292
METAL_FUNC void load_safe(const device U *src, const int ld, const short2 src_tile_dims)
Definition mma.h:381
METAL_FUNC MMATile() thread
Definition mma.h:255
METAL_FUNC void load(const threadgroup U *src)
Definition mma.h:316
METAL_FUNC constexpr void clear()
Definition mma.h:257
METAL_FUNC void load(const device U *src, const int ld)
Definition mma.h:350
MMAFrag_t::frag_type frag_type
Definition mma.h:251
MMAFrag_ MMAFrag_t
Definition mma.h:232
STEEL_CONST int kFragCols
Definition mma.h:235
METAL_FUNC constexpr const thread frag_type & frag_at(const short i, const short j) const
Definition mma.h:268
METAL_FUNC const thread elem_type * elems() const
Definition mma.h:287
STEEL_CONST int kNumFrags
Definition mma.h:244
STEEL_CONST int kElemsPerFrag
Definition mma.h:236
Shape2D(RInt r_, CInt c_)
Definition mma.h:27
RInt r
Definition mma.h:24
CInt c
Definition mma.h:25