// Mirror of https://github.com/ml-explore/mlx.git
// (synced 2025-06-24 09:21:16 +08:00)
//
// Commit notes:
// * Share KV smem
// * Fix bfloat error
// * Unroll O = S @ V loop
// * Perf upgrade
// * Remove commented out function
// * Add -Wno-c++17-extensions flag to metal flags
// * Add -Wno-c++17-extensions flag to metal extension flags
//
// File size: 748 lines, 21 KiB (C++)
// Copyright © 2024 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include <metal_simdgroup>
|
|
#include <metal_simdgroup_matrix>
|
|
#include <metal_stdlib>
|
|
|
|
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
|
|
#include "mlx/backend/metal/kernels/steel/defines.h"
|
|
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
|
|
|
|
using namespace metal;
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// MMA helper
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace mlx {
|
|
namespace steel {
|
|
|
|
/// Minimal pair of 2D extents: `r` (rows) by `c` (columns).
///
/// RInt and CInt are unconstrained here; presumably plain integers or
/// compile-time `Int<>` constants -- confirm at call sites.
template <typename RInt, typename CInt>
struct Shape2D {
  RInt r; // row extent
  CInt c; // column extent

  Shape2D(RInt rows, CInt cols) : r(rows), c(cols) {}
};
|
|
|
|
/// Aggregate that couples a shape with a layout descriptor.
///
/// Being an aggregate, it is brace-initialised in member order:
/// `Layout2D<S, L>{shape, layout}`.
template <typename Shape, typename Layout>
struct Layout2D {
  Shape shape;   // extents of the 2D region
  Layout layout; // caller-supplied layout descriptor (e.g. strides) -- TODO confirm
};
|
|
|
|
/// Primary template for a simdgroup MMA fragment.
///
/// Only the 8 x 8 partial specialisation is implemented; instantiating this
/// primary template with any other fragment size is a compile-time error.
template <typename T, int kFragRows_, int kFragCols_>
struct BaseMMAFrag {
  static_assert(
      kFragRows_ == 8 && kFragCols_ == 8,
      "Only 8 x 8 fragment matrices are currently supported");
};
|
|
|
|
template <typename T>
|
|
struct BaseMMAFrag<T, 8, 8> {
|
|
STEEL_CONST int kFragRows = 8;
|
|
STEEL_CONST int kFragCols = 8;
|
|
|
|
STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
|
|
|
|
STEEL_CONST int kElemRows = 1;
|
|
STEEL_CONST int kElemCols = 2;
|
|
|
|
static_assert(
|
|
kElemRows * kElemCols == kElemsPerFrag,
|
|
"MMAFrag shape is not consistent with MMAFrag size");
|
|
|
|
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
|
|
typedef metal::vec<T, kElemsPerFrag> frag_type;
|
|
typedef metal::vec<T, kElemRows> row_frag_type;
|
|
typedef metal::vec<T, kElemCols> col_frag_type;
|
|
|
|
template <typename U>
|
|
using dtype_mat_t = typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;
|
|
|
|
template <typename U>
|
|
using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;
|
|
|
|
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
|
|
[[thread_index_in_simdgroup]]) {
|
|
const short qid = simd_lane_id / 4;
|
|
const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
|
|
const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
|
return short2{fn, fm};
|
|
}
|
|
|
|
template <typename SrcPtrType, typename StrX, typename StrY>
|
|
METAL_FUNC static constexpr void
|
|
load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < kElemRows; i++) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < kElemCols; j++) {
|
|
dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <
|
|
typename SrcPtrType,
|
|
typename StrX,
|
|
typename StrY,
|
|
typename LimX,
|
|
typename LimY,
|
|
typename OffX,
|
|
typename OffY>
|
|
METAL_FUNC static constexpr void load_safe(
|
|
thread frag_type& dst,
|
|
SrcPtrType src,
|
|
StrX str_x,
|
|
StrY str_y,
|
|
LimX lim_x,
|
|
LimY lim_y,
|
|
OffX off_x = Int<0>{},
|
|
OffY off_y = Int<0>{}) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < kElemRows; i++) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < kElemCols; j++) {
|
|
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
|
dst[i * kElemCols + j] =
|
|
static_cast<T>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
|
|
} else {
|
|
dst[i * kElemCols + j] = T(0);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename DstPtrType, typename StrX, typename StrY>
|
|
METAL_FUNC static constexpr void
|
|
store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
|
|
using U = pointer_element_t<DstPtrType>;
|
|
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < kElemRows; i++) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < kElemCols; j++) {
|
|
dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <
|
|
typename DstPtrType,
|
|
typename StrX,
|
|
typename StrY,
|
|
typename LimX,
|
|
typename LimY,
|
|
typename OffX,
|
|
typename OffY>
|
|
METAL_FUNC static constexpr void store_safe(
|
|
const thread frag_type& src,
|
|
DstPtrType dst,
|
|
StrX str_x,
|
|
StrY str_y,
|
|
LimX lim_x,
|
|
LimY lim_y,
|
|
OffX off_x = Int<0>{},
|
|
OffY off_y = Int<0>{}) {
|
|
using U = pointer_element_t<DstPtrType>;
|
|
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < kElemRows; i++) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < kElemCols; j++) {
|
|
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
|
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
|
|
static_cast<U>(src[i * kElemCols + j]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename Atype, typename Btype, typename Ctype>
|
|
METAL_FUNC static constexpr void mma(
|
|
thread frag_type& D,
|
|
thread dtype_frag_t<Atype>& A,
|
|
thread dtype_frag_t<Btype>& B,
|
|
thread dtype_frag_t<Ctype>& C) {
|
|
mat_type D_mat;
|
|
dtype_mat_t<Atype> A_mat;
|
|
dtype_mat_t<Btype> B_mat;
|
|
dtype_mat_t<Ctype> C_mat;
|
|
|
|
reinterpret_cast<thread dtype_frag_t<Atype>&>(A_mat.thread_elements()) = A;
|
|
reinterpret_cast<thread dtype_frag_t<Btype>&>(B_mat.thread_elements()) = B;
|
|
reinterpret_cast<thread dtype_frag_t<Ctype>&>(C_mat.thread_elements()) = C;
|
|
|
|
mma(D_mat, A_mat, B_mat, C_mat);
|
|
|
|
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
|
|
}
|
|
|
|
template <typename Atype, typename Btype, typename Ctype>
|
|
METAL_FUNC static constexpr void mma(
|
|
thread mat_type& D,
|
|
thread dtype_mat_t<Atype>& A,
|
|
thread dtype_mat_t<Btype>& B,
|
|
thread dtype_mat_t<Ctype>& C) {
|
|
simdgroup_multiply_accumulate(D, A, B, C);
|
|
}
|
|
|
|
template <typename Op>
|
|
METAL_FUNC static constexpr void row_reduce(
|
|
thread const frag_type& inp_vals,
|
|
thread T* reduced_vals) {
|
|
T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);
|
|
|
|
T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
|
|
qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
|
|
|
|
T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
|
|
sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
|
|
|
|
reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
|
|
}
|
|
|
|
template <typename Op>
|
|
METAL_FUNC static constexpr void row_bin_op(
|
|
thread frag_type& inp_vals,
|
|
thread T* row_vals) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < kElemRows; i++) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < kElemCols; j++) {
|
|
inp_vals[i * kElemCols + j] =
|
|
Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
template <
|
|
typename T,
|
|
int kTileRows_,
|
|
int kTileCols_,
|
|
class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
|
|
struct MMATile {
|
|
using MMAFrag_t = MMAFrag_;
|
|
using elem_type = T;
|
|
STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
|
|
STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
|
|
STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
|
|
|
|
STEEL_CONST int kTileRows = kTileRows_;
|
|
STEEL_CONST int kTileCols = kTileCols_;
|
|
|
|
STEEL_CONST int kRows = kTileRows * kFragRows;
|
|
STEEL_CONST int kCols = kTileCols * kFragCols;
|
|
|
|
STEEL_CONST int kNumFrags = kTileRows * kTileCols;
|
|
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
|
|
|
|
STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
|
|
STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;
|
|
|
|
typedef typename MMAFrag_t::mat_type mat_type;
|
|
typedef typename MMAFrag_t::frag_type frag_type;
|
|
|
|
frag_type val_frags[kNumFrags]; // = {frag_type(0)};
|
|
|
|
METAL_FUNC MMATile() thread {}
|
|
|
|
METAL_FUNC constexpr void clear() {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < kNumFrags; ++i) {
|
|
val_frags[i] = frag_type(0);
|
|
}
|
|
}
|
|
|
|
METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
|
|
return val_frags[i * kTileCols + j];
|
|
}
|
|
|
|
METAL_FUNC constexpr const thread frag_type& frag_at(
|
|
const short i,
|
|
const short j) const {
|
|
return val_frags[i * kTileCols + j];
|
|
}
|
|
|
|
METAL_FUNC mat_type mat_at(const short i, const short j) {
|
|
mat_type val_mat;
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short ii = 0; ii < kElemsPerFrag; ++ii) {
|
|
val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
|
|
}
|
|
return val_mat;
|
|
}
|
|
|
|
METAL_FUNC thread elem_type* elems() {
|
|
return reinterpret_cast<thread elem_type*>(val_frags);
|
|
}
|
|
|
|
METAL_FUNC const thread elem_type* elems() const {
|
|
return reinterpret_cast<const thread elem_type*>(val_frags);
|
|
}
|
|
|
|
template <typename Op>
|
|
METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < kTileRows; ++i) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < kTileCols; ++j) {
|
|
MMAFrag_t::template row_reduce<Op>(
|
|
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename Op>
|
|
METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < kTileRows; ++i) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < kTileCols; ++j) {
|
|
MMAFrag_t::template row_bin_op<Op>(
|
|
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename U, int w_x, int w_y, int str_x, int str_y>
|
|
METAL_FUNC void load(const threadgroup U* src) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < kTileRows; ++i) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < kTileCols; ++j) {
|
|
MMAFrag_t::load(
|
|
frag_at(i, j),
|
|
&(
|
|
src[(i * kFragRows) * w_x * str_x +
|
|
(j * kFragCols) * w_y * str_y]),
|
|
Int<str_x>{},
|
|
Int<str_y>{});
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename U, int w_x, int w_y, int str_x, int str_y>
|
|
METAL_FUNC void store(threadgroup U* dst) const {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < kTileRows; ++i) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < kTileCols; ++j) {
|
|
MMAFrag_t::store(
|
|
frag_at(i, j),
|
|
&(
|
|
dst[(i * kFragRows) * w_x * str_x +
|
|
(j * kFragCols) * w_y * str_y]),
|
|
Int<str_x>{},
|
|
Int<str_y>{});
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename U, int w_x, int w_y>
|
|
METAL_FUNC void load(const device U* src, const int ld) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < kTileRows; ++i) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < kTileCols; ++j) {
|
|
MMAFrag_t::load(
|
|
frag_at(i, j),
|
|
&(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
|
|
ld,
|
|
Int<1>{});
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename U, int w_x, int w_y>
|
|
METAL_FUNC void store(device U* dst, const int ld) const {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < kTileRows; ++i) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < kTileCols; ++j) {
|
|
MMAFrag_t::store(
|
|
frag_at(i, j),
|
|
&(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
|
|
ld,
|
|
Int<1>{});
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename U, int w_x, int w_y>
|
|
METAL_FUNC void
|
|
load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (int i = 0; i < kTileRows; ++i) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (int j = 0; j < kTileCols; ++j) {
|
|
MMAFrag_t::load_safe(
|
|
frag_at(i, j),
|
|
src,
|
|
ld,
|
|
Int<1>{},
|
|
src_tile_dims.y,
|
|
src_tile_dims.x,
|
|
(i * kFragRows) * w_x,
|
|
(j * kFragCols) * w_y);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename U, int w_x, int w_y>
|
|
METAL_FUNC void
|
|
store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (int i = 0; i < kTileRows; ++i) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (int j = 0; j < kTileCols; ++j) {
|
|
MMAFrag_t::store_safe(
|
|
frag_at(i, j),
|
|
dst,
|
|
ld,
|
|
Int<1>{},
|
|
dst_tile_dims.y,
|
|
dst_tile_dims.x,
|
|
(i * kFragRows) * w_x,
|
|
(j * kFragCols) * w_y);
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
template <
|
|
typename Dtype,
|
|
typename Atype,
|
|
typename Btype,
|
|
typename Ctype,
|
|
int M,
|
|
int N,
|
|
int K,
|
|
class MMAFragD,
|
|
class MMAFragA,
|
|
class MMAFragB,
|
|
class MMAFragC>
|
|
METAL_FUNC void tile_matmad(
|
|
thread MMATile<Dtype, M, N, MMAFragD>& D,
|
|
thread MMATile<Atype, M, K, MMAFragA>& A,
|
|
thread MMATile<Btype, K, N, MMAFragB>& B,
|
|
thread MMATile<Ctype, M, N, MMAFragC>& C) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short m = 0; m < M; ++m) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short n = 0; n < N; ++n) {
|
|
short m_serp = m; //(n % 2) ? (M - 1 - m) : m;
|
|
short n_serp = (m % 2) ? (N - 1 - n) : n;
|
|
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short k = 0; k < K; ++k) {
|
|
MMAFragD::mma(
|
|
D.frag_at(m_serp, n_serp),
|
|
A.frag_at(m_serp, k),
|
|
B.frag_at(k, n_serp),
|
|
C.frag_at(m_serp, n_serp));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <
|
|
typename T,
|
|
typename U,
|
|
int BM,
|
|
int BN,
|
|
int BK,
|
|
int WM,
|
|
int WN,
|
|
bool transpose_a,
|
|
bool transpose_b,
|
|
short lda_tgp,
|
|
short ldb_tgp,
|
|
typename AccumType = float,
|
|
typename Epilogue = TransformNone<U, AccumType>>
|
|
struct BlockMMA {
|
|
// MMAFrag size
|
|
STEEL_CONST short kFragSize = 8;
|
|
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
|
|
|
|
// Warp tile simdgroup matrix strides along M
|
|
STEEL_CONST short TM_stride = kFragSize * WM;
|
|
// Warp tile simdgroup matrix strides along M
|
|
STEEL_CONST short TN_stride = kFragSize * WN;
|
|
|
|
// Warp tile size along M
|
|
STEEL_CONST short TM = BM / TM_stride;
|
|
// Warp tile size along N
|
|
STEEL_CONST short TN = BN / TN_stride;
|
|
|
|
// Threadgroup A strides
|
|
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
|
|
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
|
|
|
|
// Threadgroup B strides
|
|
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
|
|
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
|
|
|
|
// Threadgroup strides along K
|
|
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
|
|
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
|
|
|
|
// Simdgroup matrices
|
|
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
|
|
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
|
|
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;
|
|
|
|
// Offsets within threadgroup
|
|
short sm;
|
|
short sn;
|
|
|
|
short As_offset;
|
|
short Bs_offset;
|
|
|
|
/* Constructor */
|
|
METAL_FUNC BlockMMA(
|
|
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
|
|
// Determine thread position in simdgroup matrix
|
|
short tm = kFragSize * (simd_group_id / WN);
|
|
short tn = kFragSize * (simd_group_id % WN);
|
|
|
|
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
|
sm = simd_coord.y;
|
|
sn = simd_coord.x;
|
|
|
|
// Determine thread and simdgroup offset
|
|
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
|
|
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N
|
|
|
|
sm += tm;
|
|
sn += tn;
|
|
}
|
|
|
|
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
|
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
|
// Adjust for simdgroup and thread location
|
|
As += As_offset;
|
|
Bs += Bs_offset;
|
|
|
|
// Iterate over BK in blocks of kFragSize
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short kk = 0; kk < BK; kk += kFragSize) {
|
|
simdgroup_barrier(mem_flags::mem_none);
|
|
|
|
Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
|
|
|
|
simdgroup_barrier(mem_flags::mem_none);
|
|
|
|
Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
|
|
|
|
simdgroup_barrier(mem_flags::mem_none);
|
|
|
|
tile_matmad(Ctile, Atile, Btile, Ctile);
|
|
|
|
// Progress to next simdgroup tile
|
|
As += tile_stride_a;
|
|
Bs += tile_stride_b;
|
|
}
|
|
}
|
|
|
|
/* Store results from simdgroup_matrix results into device memory */
|
|
METAL_FUNC void store_result(device U* D, const int ldd) {
|
|
// Apply epilogue
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
|
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
|
|
}
|
|
|
|
// Adjust for simdgroup and thread location
|
|
D += sm * ldd + sn;
|
|
|
|
Ctile.template store<U, WM, WN>(D, ldd);
|
|
}
|
|
|
|
METAL_FUNC void
|
|
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
|
|
// Apply epilogue
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
|
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
|
|
}
|
|
|
|
// Adjust for simdgroup and thread location
|
|
D += sm * ldd + sn;
|
|
dst_tile_dims -= short2(sn, sm);
|
|
|
|
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
|
return;
|
|
|
|
Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
|
|
}
|
|
|
|
/* Apply epilogue */
|
|
template <typename UnaryEpilogue>
|
|
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
|
|
// Loop over all simdgroup tiles
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
|
Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
|
|
}
|
|
}
|
|
|
|
/* Apply epilogue */
|
|
template <typename BinaryEpilogue>
|
|
METAL_FUNC void apply_epilogue(
|
|
const device U* C,
|
|
const int ldc,
|
|
const int fdc,
|
|
thread const BinaryEpilogue& epilogue_op) {
|
|
// Adjust for simdgroup and thread location
|
|
C += (sm)*ldc + (sn)*fdc;
|
|
|
|
// Loop over all simdgroup tiles
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < TM; i++) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < TN; j++) {
|
|
// Get accumulated result and associated offset in C
|
|
thread auto& accum = Ctile.frag_at(i, j);
|
|
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
|
|
|
// Apply epilogue
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
|
|
accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/* Apply epilogue */
|
|
template <typename BinaryEpilogue>
|
|
METAL_FUNC void apply_epilogue_safe(
|
|
const device U* C,
|
|
const int ldc,
|
|
const int fdc,
|
|
short2 dst_tile_dims,
|
|
thread const BinaryEpilogue& epilogue_op) {
|
|
// Adjust for simdgroup and thread location
|
|
C += (sm)*ldc + (sn)*fdc;
|
|
dst_tile_dims -= short2(sn, sm);
|
|
|
|
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
|
return;
|
|
|
|
// Loop over all simdgroup tiles
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < TM; i++) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < TN; j++) {
|
|
// Get accumulated result and associated offset in C
|
|
thread auto& accum = Ctile.frag_at(i, j);
|
|
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
|
|
|
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
|
|
|
|
// Read C
|
|
U c_elems[kelems] = {0};
|
|
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short k = 0; k < kelems; k++) {
|
|
if ((j * TN_stride + k) < dst_tile_dims.x) {
|
|
c_elems[k] = C[offset_c + k * fdc];
|
|
}
|
|
}
|
|
|
|
// Apply epilogue
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short k = 0; k < kelems; k++) {
|
|
accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/* Store results from simdgroup_matrix results into device memory */
|
|
METAL_FUNC void store_result(
|
|
device U* D,
|
|
const int ldd,
|
|
const device U* C,
|
|
const int ldc,
|
|
const int fdc,
|
|
thread const Epilogue& epilogue_op) const {
|
|
// Adjust for simdgroup and thread location
|
|
C += (sm)*ldc + (sn)*fdc;
|
|
D += (sm)*ldd + sn;
|
|
|
|
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
|
|
|
|
// Loop over all simdgroup tiles
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short i = 0; i < TM; i++) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short j = 0; j < TN; j++) {
|
|
// Get accumulated result and associated offset in C
|
|
thread const auto& accum = Ctile.frag_at(i, j);
|
|
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
|
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
|
|
|
// Apply epilogue
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short k = 0; k < kelems; k++) {
|
|
D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
METAL_FUNC void store_result_safe(
|
|
device U* D,
|
|
const int ldd,
|
|
const device U* C,
|
|
const int ldc,
|
|
const int fdc,
|
|
short2 dst_tile_dims,
|
|
thread const Epilogue& epilogue_op) const {
|
|
// Adjust for simdgroup and thread location
|
|
C += (sm)*ldc + (sn)*fdc;
|
|
D += (sm)*ldd + sn;
|
|
dst_tile_dims -= short2(sn, sm);
|
|
|
|
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
|
return;
|
|
|
|
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
|
|
|
|
STEEL_PRAGMA_UNROLL
|
|
for (int i = 0; i < TM; i++) {
|
|
if (i * TM_stride < dst_tile_dims.y) {
|
|
STEEL_PRAGMA_UNROLL
|
|
for (int j = 0; j < TN; j++) {
|
|
// Get accumulated result and associated offset in C
|
|
thread const auto& accum = Ctile.frag_at(i, j);
|
|
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
|
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
|
|
|
// Apply epilogue
|
|
STEEL_PRAGMA_UNROLL
|
|
for (short k = 0; k < kelems; k++) {
|
|
if ((j * TN_stride + k) < dst_tile_dims.x) {
|
|
D[offset_d + k] =
|
|
epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
} // namespace steel
|
|
} // namespace mlx
|