mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 18:11:17 +08:00
[WIP]: Reductions and min working aligned kernel at headdim = 64
This commit is contained in:
parent
168a3a464a
commit
d927ed9e32
@ -29,6 +29,48 @@ struct TransformScale {
|
||||
}
|
||||
};
|
||||
|
||||
struct MaxOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return metal::max(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct SumOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return x + y;
|
||||
}
|
||||
};
|
||||
|
||||
struct MulOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return x * y;
|
||||
}
|
||||
};
|
||||
|
||||
struct SubOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return x - y;
|
||||
}
|
||||
};
|
||||
|
||||
struct ExpSubOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return fast::exp(x - y);
|
||||
}
|
||||
};
|
||||
|
||||
struct DivOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return x / y;
|
||||
}
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
template <
|
||||
typename T,
|
||||
@ -130,11 +172,12 @@ template <
|
||||
BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
|
||||
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
|
||||
|
||||
constexpr int TQ = BQ / (kNWarps * kFragSize); // Q seq frags per warp
|
||||
constexpr int TK =
|
||||
BK / kFragSize; // KV sequence frags (all warps load the same frags)
|
||||
constexpr int TD =
|
||||
BD / kFragSize; // HeadDim frags (all warps load the same frags)
|
||||
// Q seq frags per warp
|
||||
constexpr int TQ = BQ / (kNWarps * kFragSize);
|
||||
// KV sequence frags (all warps load the same frags)
|
||||
constexpr int TK = BK / kFragSize;
|
||||
// HeadDim frags (all warps load the same frags)
|
||||
constexpr int TD = BD / kFragSize;
|
||||
|
||||
static_assert(TQ == 1, "Check TQ");
|
||||
|
||||
@ -163,6 +206,16 @@ template <
|
||||
loader_q.load_unsafe();
|
||||
loader_q.apply_inplace_op(ts);
|
||||
|
||||
constexpr int kRowsPT = decltype(Stile)::kRowsPerThread;
|
||||
|
||||
AccumType max_score[kRowsPT];
|
||||
AccumType sum_score[kRowsPT] = {0};
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
max_score[i] = Limits<AccumType>::min;
|
||||
}
|
||||
|
||||
for (int kb = 0; kb < params->NK; kb++) {
|
||||
// Load Q and K blocks and apply scale
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@ -193,6 +246,43 @@ template <
|
||||
|
||||
// Do softmax
|
||||
|
||||
// Row max
|
||||
AccumType new_max[kRowsPT];
|
||||
AccumType factor[kRowsPT];
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
new_max[i] = max_score[i];
|
||||
}
|
||||
|
||||
Stile.template row_reduce<MaxOp>(new_max);
|
||||
|
||||
// exp(Si - rowmax(Si))
|
||||
Stile.template row_bin_op<ExpSubOp>(new_max);
|
||||
|
||||
// Factor exp(rowmax(Si-1) - rowmax(Si))
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
factor[i] = fast::exp(max_score[i] - new_max[i]);
|
||||
}
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
max_score[i] = new_max[i];
|
||||
}
|
||||
|
||||
// Row Sum
|
||||
AccumType sum_score_tmp[kRowsPT] = {0};
|
||||
Stile.template row_reduce<SumOp>(sum_score_tmp);
|
||||
|
||||
// Update norm
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
|
||||
}
|
||||
|
||||
// Update O
|
||||
Otile.template row_bin_op<MulOp>(factor);
|
||||
|
||||
// Do O = S @ V
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
|
||||
@ -207,6 +297,7 @@ template <
|
||||
loader_v.next();
|
||||
}
|
||||
|
||||
Otile.template row_bin_op<DivOp>(sum_score);
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Store results
|
||||
|
@ -59,6 +59,8 @@ struct BaseMMAFrag<T, 8, 8> {
|
||||
|
||||
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
|
||||
typedef metal::vec<T, kElemsPerFrag> frag_type;
|
||||
typedef metal::vec<T, kElemRows> row_frag_type;
|
||||
typedef metal::vec<T, kElemCols> col_frag_type;
|
||||
|
||||
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
|
||||
[[thread_index_in_simdgroup]]) {
|
||||
@ -182,6 +184,35 @@ struct BaseMMAFrag<T, 8, 8> {
|
||||
thread mat_type& C) {
|
||||
simdgroup_multiply_accumulate(D, A, B, C);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
METAL_FUNC static constexpr void row_reduce(
|
||||
thread const frag_type& inp_vals,
|
||||
thread T* reduced_vals) {
|
||||
T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);
|
||||
|
||||
T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
|
||||
qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
|
||||
|
||||
T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
|
||||
sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
|
||||
|
||||
reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
METAL_FUNC static constexpr void row_bin_op(
|
||||
thread frag_type& inp_vals,
|
||||
thread T* row_vals) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
inp_vals[i * kElemCols + j] =
|
||||
Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
@ -205,6 +236,9 @@ struct MMATile {
|
||||
STEEL_CONST int kNumFrags = kTileRows * kTileCols;
|
||||
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
|
||||
|
||||
STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
|
||||
STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;
|
||||
|
||||
typedef typename MMAFrag_t::mat_type mat_type;
|
||||
typedef typename MMAFrag_t::frag_type frag_type;
|
||||
|
||||
@ -246,6 +280,30 @@ struct MMATile {
|
||||
return reinterpret_cast<const thread elem_type*>(val_frags);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::template row_reduce<Op>(
|
||||
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::template row_bin_op<Op>(
|
||||
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y, int str_x, int str_y>
|
||||
METAL_FUNC void load(const threadgroup U* src) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
|
Loading…
Reference in New Issue
Block a user