mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 18:11:17 +08:00
[WIP]: Reductions and min working aligned kernel at headdim = 64
This commit is contained in:
parent
168a3a464a
commit
d927ed9e32
@ -29,6 +29,48 @@ struct TransformScale {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct MaxOp {
|
||||||
|
template <typename T>
|
||||||
|
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||||
|
return metal::max(x, y);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SumOp {
|
||||||
|
template <typename T>
|
||||||
|
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||||
|
return x + y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MulOp {
|
||||||
|
template <typename T>
|
||||||
|
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||||
|
return x * y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SubOp {
|
||||||
|
template <typename T>
|
||||||
|
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||||
|
return x - y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ExpSubOp {
|
||||||
|
template <typename T>
|
||||||
|
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||||
|
return fast::exp(x - y);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DivOp {
|
||||||
|
template <typename T>
|
||||||
|
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||||
|
return x / y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
@ -130,11 +172,12 @@ template <
|
|||||||
BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
|
BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
|
||||||
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
|
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
|
||||||
|
|
||||||
constexpr int TQ = BQ / (kNWarps * kFragSize); // Q seq frags per warp
|
// Q seq frags per warp
|
||||||
constexpr int TK =
|
constexpr int TQ = BQ / (kNWarps * kFragSize);
|
||||||
BK / kFragSize; // KV sequence frags (all warps load the same frags)
|
// KV sequence frags (all warps load the same frags)
|
||||||
constexpr int TD =
|
constexpr int TK = BK / kFragSize;
|
||||||
BD / kFragSize; // HeadDim frags (all warps load the same frags)
|
// HeadDim frags (all warps load the same frags)
|
||||||
|
constexpr int TD = BD / kFragSize;
|
||||||
|
|
||||||
static_assert(TQ == 1, "Check TQ");
|
static_assert(TQ == 1, "Check TQ");
|
||||||
|
|
||||||
@ -163,6 +206,16 @@ template <
|
|||||||
loader_q.load_unsafe();
|
loader_q.load_unsafe();
|
||||||
loader_q.apply_inplace_op(ts);
|
loader_q.apply_inplace_op(ts);
|
||||||
|
|
||||||
|
constexpr int kRowsPT = decltype(Stile)::kRowsPerThread;
|
||||||
|
|
||||||
|
AccumType max_score[kRowsPT];
|
||||||
|
AccumType sum_score[kRowsPT] = {0};
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
|
max_score[i] = Limits<AccumType>::min;
|
||||||
|
}
|
||||||
|
|
||||||
for (int kb = 0; kb < params->NK; kb++) {
|
for (int kb = 0; kb < params->NK; kb++) {
|
||||||
// Load Q and K blocks and apply scale
|
// Load Q and K blocks and apply scale
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
@ -193,6 +246,43 @@ template <
|
|||||||
|
|
||||||
// Do softmax
|
// Do softmax
|
||||||
|
|
||||||
|
// Row max
|
||||||
|
AccumType new_max[kRowsPT];
|
||||||
|
AccumType factor[kRowsPT];
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
|
new_max[i] = max_score[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
Stile.template row_reduce<MaxOp>(new_max);
|
||||||
|
|
||||||
|
// exp(Si - rowmax(Si))
|
||||||
|
Stile.template row_bin_op<ExpSubOp>(new_max);
|
||||||
|
|
||||||
|
// Factor exp(rowmax(Si) - rowmax(Si-1))
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
|
factor[i] = fast::exp(max_score[i] - new_max[i]);
|
||||||
|
}
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
|
max_score[i] = new_max[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Row Sum
|
||||||
|
AccumType sum_score_tmp[kRowsPT] = {0};
|
||||||
|
Stile.template row_reduce<SumOp>(sum_score_tmp);
|
||||||
|
|
||||||
|
// Update norm
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
|
sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update O
|
||||||
|
Otile.template row_bin_op<MulOp>(factor);
|
||||||
|
|
||||||
// Do O = S @ V
|
// Do O = S @ V
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
|
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
|
||||||
@ -207,6 +297,7 @@ template <
|
|||||||
loader_v.next();
|
loader_v.next();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Otile.template row_bin_op<DivOp>(sum_score);
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
// Store results
|
// Store results
|
||||||
|
@ -59,6 +59,8 @@ struct BaseMMAFrag<T, 8, 8> {
|
|||||||
|
|
||||||
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
|
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
|
||||||
typedef metal::vec<T, kElemsPerFrag> frag_type;
|
typedef metal::vec<T, kElemsPerFrag> frag_type;
|
||||||
|
typedef metal::vec<T, kElemRows> row_frag_type;
|
||||||
|
typedef metal::vec<T, kElemCols> col_frag_type;
|
||||||
|
|
||||||
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
|
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
|
||||||
[[thread_index_in_simdgroup]]) {
|
[[thread_index_in_simdgroup]]) {
|
||||||
@ -182,6 +184,35 @@ struct BaseMMAFrag<T, 8, 8> {
|
|||||||
thread mat_type& C) {
|
thread mat_type& C) {
|
||||||
simdgroup_multiply_accumulate(D, A, B, C);
|
simdgroup_multiply_accumulate(D, A, B, C);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
METAL_FUNC static constexpr void row_reduce(
|
||||||
|
thread const frag_type& inp_vals,
|
||||||
|
thread T* reduced_vals) {
|
||||||
|
T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);
|
||||||
|
|
||||||
|
T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
|
||||||
|
qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
|
||||||
|
|
||||||
|
T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
|
||||||
|
sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
|
||||||
|
|
||||||
|
reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
METAL_FUNC static constexpr void row_bin_op(
|
||||||
|
thread frag_type& inp_vals,
|
||||||
|
thread T* row_vals) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < kElemRows; i++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < kElemCols; j++) {
|
||||||
|
inp_vals[i * kElemCols + j] =
|
||||||
|
Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <
|
template <
|
||||||
@ -205,6 +236,9 @@ struct MMATile {
|
|||||||
STEEL_CONST int kNumFrags = kTileRows * kTileCols;
|
STEEL_CONST int kNumFrags = kTileRows * kTileCols;
|
||||||
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
|
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
|
||||||
|
|
||||||
|
STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
|
||||||
|
STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;
|
||||||
|
|
||||||
typedef typename MMAFrag_t::mat_type mat_type;
|
typedef typename MMAFrag_t::mat_type mat_type;
|
||||||
typedef typename MMAFrag_t::frag_type frag_type;
|
typedef typename MMAFrag_t::frag_type frag_type;
|
||||||
|
|
||||||
@ -246,6 +280,30 @@ struct MMATile {
|
|||||||
return reinterpret_cast<const thread elem_type*>(val_frags);
|
return reinterpret_cast<const thread elem_type*>(val_frags);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < kTileRows; ++i) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < kTileCols; ++j) {
|
||||||
|
MMAFrag_t::template row_reduce<Op>(
|
||||||
|
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < kTileRows; ++i) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < kTileCols; ++j) {
|
||||||
|
MMAFrag_t::template row_bin_op<Op>(
|
||||||
|
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename U, int w_x, int w_y, int str_x, int str_y>
|
template <typename U, int w_x, int w_y, int str_x, int str_y>
|
||||||
METAL_FUNC void load(const threadgroup U* src) {
|
METAL_FUNC void load(const threadgroup U* src) {
|
||||||
STEEL_PRAGMA_UNROLL
|
STEEL_PRAGMA_UNROLL
|
||||||
|
Loading…
Reference in New Issue
Block a user