[WIP] Update dispatch params for testing

This commit is contained in:
Jagrit Digani 2024-11-20 09:39:48 -08:00
parent 2cd1de0e47
commit c1dc852995
2 changed files with 39 additions and 32 deletions

View File

@ -91,6 +91,10 @@ template <
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
// Pacifying compiler
(void)lid;
// Move to correct block
ulong3 tidl{tid.x, tid.y, tid.z}; ulong3 tidl{tid.x, tid.y, tid.z};
Q += tidl.z * params->Q_strides[0] + // Batch Q += tidl.z * params->Q_strides[0] + // Batch
@ -107,26 +111,20 @@ template <
tidl.y * params->O_strides[1] + // Head tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Seqeunce tidl.x * BQ * params->O_strides[2]; // Seqeunce
constexpr int padQ = 0; // 16 / sizeof(T); // Prepare threadgroup memory
constexpr int padK = 0; // 16 / sizeof(T); constexpr short padQ = 0; // 16 / sizeof(T);
constexpr int padV = 0; // 16 / sizeof(T); constexpr short padK = 0; // 16 / sizeof(T);
constexpr short padV = 0; // 16 / sizeof(T);
// using QBlockSrcShape = CShape<BQ, BD>; constexpr short LDQ_tgp = BD + padQ;
// using KBlockSrcShape = CShape<BK, BD>; constexpr short LDK_tgp = BK + padK;
// using VBlockSrcShape = CShape<BK, BD>; constexpr short LDV_tgp = BD + padV;
constexpr int LDQ_tgp = BD + padQ;
constexpr int LDK_tgp = BK + padK;
constexpr int LDV_tgp = BD + padV;
// using QBlockDstStrides = CShape<LDQ_tgp, 1>;
// using KBlockDstStrides = CShape<1, LDK_tgp>;
// using QBlockDstStrides = CShape<LDV_tgp, 1>;
threadgroup T Qs[BQ * (BD + padQ)]; threadgroup T Qs[BQ * (BD + padQ)];
threadgroup T Ks[(BK + padK) * BD]; threadgroup T Ks[(BK + padK) * BD];
threadgroup T Vs[BK * (BD + padV)]; threadgroup T Vs[BK * (BD + padV)];
// Prepare block loaders
using QBlockLoader = BlockLoaderT< using QBlockLoader = BlockLoaderT<
/* typename T = */ T, /* typename T = */ T,
/* short BROWS = */ BQ, /* short BROWS = */ BQ,
@ -136,6 +134,7 @@ template <
/* short reduction_dim = */ 1, /* short reduction_dim = */ 1,
/* short tgp_size = */ WM * WN * 32>; /* short tgp_size = */ WM * WN * 32>;
// K is loaded in transposed
using KBlockLoader = BlockLoaderT< using KBlockLoader = BlockLoaderT<
/* typename T = */ T, /* typename T = */ T,
/* short BROWS = */ BK, /* short BROWS = */ BK,
@ -163,8 +162,8 @@ template <
TransformScale<T> ts(static_cast<T>(params->scale)); TransformScale<T> ts(static_cast<T>(params->scale));
// MMAFrag size // Prepare MMA tiles
constexpr short kFragSize = 8; constexpr short kFragSize = 8; // MMAFrag size
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>; using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
constexpr int kNWarps = WM * WN; constexpr int kNWarps = WM * WN;
@ -189,35 +188,40 @@ template <
Otile.clear(); Otile.clear();
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); // Prepare mma tile offsets
short sm = simd_coord.y; const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
short sn = simd_coord.x; const short sm = simd_coord.y;
short tm = kFragSize * TQ * simd_group_id; const short sn = simd_coord.x;
const short tm = kFragSize * TQ * simd_group_id;
short Qs_offset = (tm + sm) * LDQ_tgp + sn; const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
short Ks_offset = sm * LDK_tgp + sn; const short Ks_offset = sm * LDK_tgp + sn;
short Vs_offset = sm * LDV_tgp + sn; const short Vs_offset = sm * LDV_tgp + sn;
constexpr int Qs_tile_stride = kFragSize; constexpr short Qs_tile_stride = kFragSize;
constexpr int Ks_tile_stride = kFragSize * LDK_tgp; constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
// Load Q blocks apply scale
loader_q.load_unsafe(); loader_q.load_unsafe();
loader_q.apply_inplace_op(ts); loader_q.apply_inplace_op(ts);
constexpr int kRowsPT = decltype(Stile)::kRowsPerThread; // Init row reduction variables
constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;
AccumType max_score[kRowsPT]; AccumType max_score[kRowsPT];
AccumType sum_score[kRowsPT] = {0}; AccumType sum_score[kRowsPT] = {0};
// Init to -Inf
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) { for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::min; max_score[i] = Limits<AccumType>::min;
} }
// Loop over KV seq length
for (int kb = 0; kb < params->NK; kb++) { for (int kb = 0; kb < params->NK; kb++) {
// Load Q and K blocks and apply scale // Load K block and apply scale
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
loader_k.load_unsafe(); loader_k.load_unsafe();
@ -246,15 +250,15 @@ template <
// Do softmax // Do softmax
// Row max // Temp variables
AccumType new_max[kRowsPT]; AccumType new_max[kRowsPT];
AccumType factor[kRowsPT]; AccumType factor[kRowsPT];
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) { for (short i = 0; i < kRowsPT; ++i) {
new_max[i] = max_score[i]; new_max[i] = max_score[i];
} }
// Row max
Stile.template row_reduce<MaxOp>(new_max); Stile.template row_reduce<MaxOp>(new_max);
// exp(Si - rowmax(Si)) // exp(Si - rowmax(Si))
@ -265,6 +269,8 @@ template <
for (short i = 0; i < kRowsPT; ++i) { for (short i = 0; i < kRowsPT; ++i) {
factor[i] = fast::exp(max_score[i] - new_max[i]); factor[i] = fast::exp(max_score[i] - new_max[i]);
} }
// Save max for next iteration
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) { for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = new_max[i]; max_score[i] = new_max[i];
@ -283,20 +289,21 @@ template <
// Update O // Update O
Otile.template row_bin_op<MulOp>(factor); Otile.template row_bin_op<MulOp>(factor);
// Do O = S @ V // Load V into registers
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]); Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
simdgroup_barrier(mem_flags::mem_none); simdgroup_barrier(mem_flags::mem_none);
// Do O = S @ V
tile_matmad(Otile, Stile, Vtile, Otile); tile_matmad(Otile, Stile, Vtile, Otile);
// Prepare for next iteration // Prepare for next iteration
// loader_q.next();
loader_k.next(); loader_k.next();
loader_v.next(); loader_v.next();
} }
// Normalize output
Otile.template row_bin_op<DivOp>(sum_score); Otile.template row_bin_op<DivOp>(sum_score);
threadgroup_barrier(mem_flags::mem_none); threadgroup_barrier(mem_flags::mem_none);

View File

@ -600,7 +600,7 @@ array scaled_dot_product_attention(
* * dtype is not fp32 or fp16 * * dtype is not fp32 or fp16
*/ */
int threshold = 1024; // TODO: Fix after dev int threshold = 32; // TODO: Fix after dev
if (memory_efficient_threshold.has_value()) { if (memory_efficient_threshold.has_value()) {
threshold = std::max(1, memory_efficient_threshold.value()); threshold = std::max(1, memory_efficient_threshold.value());
} }