[WIP]: Reductions and minimal working aligned kernel at headdim = 64

This commit is contained in:
Jagrit Digani 2024-11-19 14:10:05 -08:00
parent 168a3a464a
commit d927ed9e32
2 changed files with 154 additions and 5 deletions

View File

@ -29,6 +29,48 @@ struct TransformScale {
} }
}; };
// Reduction functor: elementwise maximum of two values.
struct MaxOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T a, T b) {
    // Delegate to the Metal builtin so its overload/NaN behavior is kept.
    return metal::max(a, b);
  }
};
// Reduction functor: sum of two values.
struct SumOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T a, T b) {
    return a + b;
  }
};
// Binary functor: product of two values.
struct MulOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T a, T b) {
    return a * b;
  }
};
// Binary functor: difference (first operand minus second).
struct SubOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T a, T b) {
    return a - b;
  }
};
// Binary functor: exp(a - b) — the fused exponentiated difference used
// when exponentiating scores shifted by the running row max.
struct ExpSubOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T a, T b) {
    // fast::exp is the fast-math variant; precision trade-off is deliberate.
    return fast::exp(a - b);
  }
};
// Binary functor: quotient (first operand divided by second).
struct DivOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T a, T b) {
    return a / b;
  }
};
// clang-format off // clang-format off
template < template <
typename T, typename T,
@ -130,11 +172,12 @@ template <
BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0, BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence."); "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
constexpr int TQ = BQ / (kNWarps * kFragSize); // Q seq frags per warp // Q seq frags per warp
constexpr int TK = constexpr int TQ = BQ / (kNWarps * kFragSize);
BK / kFragSize; // KV sequence frags (all warps load the same frags) // KV sequence frags (all warps load the same frags)
constexpr int TD = constexpr int TK = BK / kFragSize;
BD / kFragSize; // HeadDim frags (all warps load the same frags) // HeadDim frags (all warps load the same frags)
constexpr int TD = BD / kFragSize;
static_assert(TQ == 1, "Check TQ"); static_assert(TQ == 1, "Check TQ");
@ -163,6 +206,16 @@ template <
loader_q.load_unsafe(); loader_q.load_unsafe();
loader_q.apply_inplace_op(ts); loader_q.apply_inplace_op(ts);
constexpr int kRowsPT = decltype(Stile)::kRowsPerThread;
AccumType max_score[kRowsPT];
AccumType sum_score[kRowsPT] = {0};
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::min;
}
for (int kb = 0; kb < params->NK; kb++) { for (int kb = 0; kb < params->NK; kb++) {
// Load Q and K blocks and apply scale // Load Q and K blocks and apply scale
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
@ -193,6 +246,43 @@ template <
// Do softmax // Do softmax
// Row max
AccumType new_max[kRowsPT];
AccumType factor[kRowsPT];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
new_max[i] = max_score[i];
}
Stile.template row_reduce<MaxOp>(new_max);
// exp(Si - rowmax(Si))
Stile.template row_bin_op<ExpSubOp>(new_max);
// Factor exp(rowmax(Si) - rowmax(Si-1))
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
factor[i] = fast::exp(max_score[i] - new_max[i]);
}
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = new_max[i];
}
// Row Sum
AccumType sum_score_tmp[kRowsPT] = {0};
Stile.template row_reduce<SumOp>(sum_score_tmp);
// Update norm
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
}
// Update O
Otile.template row_bin_op<MulOp>(factor);
// Do O = S @ V // Do O = S @ V
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]); Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
@ -207,6 +297,7 @@ template <
loader_v.next(); loader_v.next();
} }
Otile.template row_bin_op<DivOp>(sum_score);
threadgroup_barrier(mem_flags::mem_none); threadgroup_barrier(mem_flags::mem_none);
// Store results // Store results

View File

@ -59,6 +59,8 @@ struct BaseMMAFrag<T, 8, 8> {
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type; typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
typedef metal::vec<T, kElemsPerFrag> frag_type; typedef metal::vec<T, kElemsPerFrag> frag_type;
typedef metal::vec<T, kElemRows> row_frag_type;
typedef metal::vec<T, kElemCols> col_frag_type;
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
[[thread_index_in_simdgroup]]) { [[thread_index_in_simdgroup]]) {
@ -182,6 +184,35 @@ struct BaseMMAFrag<T, 8, 8> {
thread mat_type& C) { thread mat_type& C) {
simdgroup_multiply_accumulate(D, A, B, C); simdgroup_multiply_accumulate(D, A, B, C);
} }
// Reduce this thread's fragment values along their row with `Op`,
// accumulating the result into reduced_vals[0].
// NOTE(review): only .x/.y are combined, so this assumes two elements per
// thread per fragment; the xor-1 / xor-8 lane hops presumably gather the
// remaining columns of the same output row in the 8x8 simdgroup-matrix
// layout — verify against get_coord.
template <typename Op>
METAL_FUNC static constexpr void row_reduce(
    thread const frag_type& inp_vals,
    thread T* reduced_vals) {
  // Combine the two values owned by this thread.
  T acc = Op::apply(inp_vals.x, inp_vals.y);
  // Fold in the neighboring lane (xor 1), then the lane eight apart (xor 8).
  acc = Op::apply(acc, simd_shuffle_xor(acc, ushort(1)));
  acc = Op::apply(acc, simd_shuffle_xor(acc, ushort(8)));
  // Merge into the caller's running row value.
  reduced_vals[0] = Op::apply(reduced_vals[0], acc);
}
// Apply `Op` elementwise between the fragment and a per-row operand:
// every element of fragment row r is combined in place with row_vals[r].
template <typename Op>
METAL_FUNC static constexpr void row_bin_op(
    thread frag_type& inp_vals,
    thread T* row_vals) {
  STEEL_PRAGMA_UNROLL
  for (short r = 0; r < kElemRows; ++r) {
    // One operand per fragment row, broadcast across its columns.
    const T row_operand = row_vals[r];
    STEEL_PRAGMA_UNROLL
    for (short c = 0; c < kElemCols; ++c) {
      const short idx = r * kElemCols + c;
      inp_vals[idx] = Op::apply(inp_vals[idx], row_operand);
    }
  }
}
}; };
template < template <
@ -205,6 +236,9 @@ struct MMATile {
STEEL_CONST int kNumFrags = kTileRows * kTileCols; STEEL_CONST int kNumFrags = kTileRows * kTileCols;
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;
typedef typename MMAFrag_t::mat_type mat_type; typedef typename MMAFrag_t::mat_type mat_type;
typedef typename MMAFrag_t::frag_type frag_type; typedef typename MMAFrag_t::frag_type frag_type;
@ -246,6 +280,30 @@ struct MMATile {
return reinterpret_cast<const thread elem_type*>(val_frags); return reinterpret_cast<const thread elem_type*>(val_frags);
} }
// Reduce every fragment of the tile along its rows with `Op`,
// accumulating into vals (one slot per thread-owned row).
template <typename Op>
METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
  STEEL_PRAGMA_UNROLL
  for (short tr = 0; tr < kTileRows; ++tr) {
    // All fragments in tile-row tr accumulate into the same output slots.
    thread T* row_out = &vals[tr * MMAFrag_t::kElemRows];
    STEEL_PRAGMA_UNROLL
    for (short tc = 0; tc < kTileCols; ++tc) {
      MMAFrag_t::template row_reduce<Op>(frag_at(tr, tc), row_out);
    }
  }
}
// Apply `Op` between each fragment and the per-row operands in vals,
// updating the tile in place.
template <typename Op>
METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
  STEEL_PRAGMA_UNROLL
  for (short tr = 0; tr < kTileRows; ++tr) {
    // Fragments in the same tile-row share the same operand slots.
    thread T* row_in = &vals[tr * MMAFrag_t::kElemRows];
    STEEL_PRAGMA_UNROLL
    for (short tc = 0; tc < kTileCols; ++tc) {
      MMAFrag_t::template row_bin_op<Op>(frag_at(tr, tc), row_in);
    }
  }
}
template <typename U, int w_x, int w_y, int str_x, int str_y> template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void load(const threadgroup U* src) { METAL_FUNC void load(const threadgroup U* src) {
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL