[WIP] Update dispatch params for testing

This commit is contained in:
Jagrit Digani 2024-11-20 09:39:48 -08:00
parent 2cd1de0e47
commit c1dc852995
2 changed files with 39 additions and 32 deletions

View File

@ -91,6 +91,10 @@ template <
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
// Pacifying compiler
(void)lid;
// Move to correct block
ulong3 tidl{tid.x, tid.y, tid.z};
Q += tidl.z * params->Q_strides[0] + // Batch
@ -107,26 +111,20 @@ template <
tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Seqeunce
constexpr int padQ = 0; // 16 / sizeof(T);
constexpr int padK = 0; // 16 / sizeof(T);
constexpr int padV = 0; // 16 / sizeof(T);
// Prepare threadgroup memory
constexpr short padQ = 0; // 16 / sizeof(T);
constexpr short padK = 0; // 16 / sizeof(T);
constexpr short padV = 0; // 16 / sizeof(T);
// using QBlockSrcShape = CShape<BQ, BD>;
// using KBlockSrcShape = CShape<BK, BD>;
// using VBlockSrcShape = CShape<BK, BD>;
constexpr int LDQ_tgp = BD + padQ;
constexpr int LDK_tgp = BK + padK;
constexpr int LDV_tgp = BD + padV;
// using QBlockDstStrides = CShape<LDQ_tgp, 1>;
// using KBlockDstStrides = CShape<1, LDK_tgp>;
// using QBlockDstStrides = CShape<LDV_tgp, 1>;
constexpr short LDQ_tgp = BD + padQ;
constexpr short LDK_tgp = BK + padK;
constexpr short LDV_tgp = BD + padV;
threadgroup T Qs[BQ * (BD + padQ)];
threadgroup T Ks[(BK + padK) * BD];
threadgroup T Vs[BK * (BD + padV)];
// Prepare block loaders
using QBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BQ,
@ -136,6 +134,7 @@ template <
/* short reduction_dim = */ 1,
/* short tgp_size = */ WM * WN * 32>;
// K is loaded in transposed
using KBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BK,
@ -163,8 +162,8 @@ template <
TransformScale<T> ts(static_cast<T>(params->scale));
// MMAFrag size
constexpr short kFragSize = 8;
// Prepare MMA tiles
constexpr short kFragSize = 8; // MMAFrag size
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
constexpr int kNWarps = WM * WN;
@ -189,35 +188,40 @@ template <
Otile.clear();
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
short sm = simd_coord.y;
short sn = simd_coord.x;
short tm = kFragSize * TQ * simd_group_id;
// Prepare mma tile offsets
const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
const short sm = simd_coord.y;
const short sn = simd_coord.x;
const short tm = kFragSize * TQ * simd_group_id;
short Qs_offset = (tm + sm) * LDQ_tgp + sn;
short Ks_offset = sm * LDK_tgp + sn;
short Vs_offset = sm * LDV_tgp + sn;
const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
const short Ks_offset = sm * LDK_tgp + sn;
const short Vs_offset = sm * LDV_tgp + sn;
constexpr int Qs_tile_stride = kFragSize;
constexpr int Ks_tile_stride = kFragSize * LDK_tgp;
constexpr short Qs_tile_stride = kFragSize;
constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load Q blocks apply scale
loader_q.load_unsafe();
loader_q.apply_inplace_op(ts);
constexpr int kRowsPT = decltype(Stile)::kRowsPerThread;
// Init row reduction variables
constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;
AccumType max_score[kRowsPT];
AccumType sum_score[kRowsPT] = {0};
// Init to -Inf
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::min;
}
// Loop over KV seq length
for (int kb = 0; kb < params->NK; kb++) {
// Load Q and K blocks and apply scale
// Load K block and apply scale
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_k.load_unsafe();
@ -246,15 +250,15 @@ template <
// Do softmax
// Row max
// Temp variables
AccumType new_max[kRowsPT];
AccumType factor[kRowsPT];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
new_max[i] = max_score[i];
}
// Row max
Stile.template row_reduce<MaxOp>(new_max);
// exp(Si - rowmax(Si))
@ -265,6 +269,8 @@ template <
for (short i = 0; i < kRowsPT; ++i) {
factor[i] = fast::exp(max_score[i] - new_max[i]);
}
// Save max for next iteration
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = new_max[i];
@ -283,20 +289,21 @@ template <
// Update O
Otile.template row_bin_op<MulOp>(factor);
// Do O = S @ V
// Load V into registers
threadgroup_barrier(mem_flags::mem_threadgroup);
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
simdgroup_barrier(mem_flags::mem_none);
// Do O = S @ V
tile_matmad(Otile, Stile, Vtile, Otile);
// Prepare for next iteration
// loader_q.next();
loader_k.next();
loader_v.next();
}
// Normalize output
Otile.template row_bin_op<DivOp>(sum_score);
threadgroup_barrier(mem_flags::mem_none);

View File

@ -600,7 +600,7 @@ array scaled_dot_product_attention(
* * dtype is not fp32 or fp16
*/
int threshold = 1024; // TODO: Fix after dev
int threshold = 32; // TODO: Fix after dev
if (memory_efficient_threshold.has_value()) {
threshold = std::max(1, memory_efficient_threshold.value());
}