mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 16:51:15 +08:00
[WIP] Update dispatch params for testing
This commit is contained in:
parent
2cd1de0e47
commit
c1dc852995
@ -91,6 +91,10 @@ template <
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
|
||||
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
// Move to correct block
|
||||
ulong3 tidl{tid.x, tid.y, tid.z};
|
||||
|
||||
Q += tidl.z * params->Q_strides[0] + // Batch
|
||||
@ -107,26 +111,20 @@ template <
|
||||
tidl.y * params->O_strides[1] + // Head
|
||||
tidl.x * BQ * params->O_strides[2]; // Seqeunce
|
||||
|
||||
constexpr int padQ = 0; // 16 / sizeof(T);
|
||||
constexpr int padK = 0; // 16 / sizeof(T);
|
||||
constexpr int padV = 0; // 16 / sizeof(T);
|
||||
// Prepare threadgroup memory
|
||||
constexpr short padQ = 0; // 16 / sizeof(T);
|
||||
constexpr short padK = 0; // 16 / sizeof(T);
|
||||
constexpr short padV = 0; // 16 / sizeof(T);
|
||||
|
||||
// using QBlockSrcShape = CShape<BQ, BD>;
|
||||
// using KBlockSrcShape = CShape<BK, BD>;
|
||||
// using VBlockSrcShape = CShape<BK, BD>;
|
||||
|
||||
constexpr int LDQ_tgp = BD + padQ;
|
||||
constexpr int LDK_tgp = BK + padK;
|
||||
constexpr int LDV_tgp = BD + padV;
|
||||
|
||||
// using QBlockDstStrides = CShape<LDQ_tgp, 1>;
|
||||
// using KBlockDstStrides = CShape<1, LDK_tgp>;
|
||||
// using QBlockDstStrides = CShape<LDV_tgp, 1>;
|
||||
constexpr short LDQ_tgp = BD + padQ;
|
||||
constexpr short LDK_tgp = BK + padK;
|
||||
constexpr short LDV_tgp = BD + padV;
|
||||
|
||||
threadgroup T Qs[BQ * (BD + padQ)];
|
||||
threadgroup T Ks[(BK + padK) * BD];
|
||||
threadgroup T Vs[BK * (BD + padV)];
|
||||
|
||||
// Prepare block loaders
|
||||
using QBlockLoader = BlockLoaderT<
|
||||
/* typename T = */ T,
|
||||
/* short BROWS = */ BQ,
|
||||
@ -136,6 +134,7 @@ template <
|
||||
/* short reduction_dim = */ 1,
|
||||
/* short tgp_size = */ WM * WN * 32>;
|
||||
|
||||
// K is loaded in transposed
|
||||
using KBlockLoader = BlockLoaderT<
|
||||
/* typename T = */ T,
|
||||
/* short BROWS = */ BK,
|
||||
@ -163,8 +162,8 @@ template <
|
||||
|
||||
TransformScale<T> ts(static_cast<T>(params->scale));
|
||||
|
||||
// MMAFrag size
|
||||
constexpr short kFragSize = 8;
|
||||
// Prepare MMA tiles
|
||||
constexpr short kFragSize = 8; // MMAFrag size
|
||||
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
|
||||
|
||||
constexpr int kNWarps = WM * WN;
|
||||
@ -189,35 +188,40 @@ template <
|
||||
|
||||
Otile.clear();
|
||||
|
||||
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
||||
short sm = simd_coord.y;
|
||||
short sn = simd_coord.x;
|
||||
short tm = kFragSize * TQ * simd_group_id;
|
||||
// Prepare mma tile offsets
|
||||
const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
||||
const short sm = simd_coord.y;
|
||||
const short sn = simd_coord.x;
|
||||
const short tm = kFragSize * TQ * simd_group_id;
|
||||
|
||||
short Qs_offset = (tm + sm) * LDQ_tgp + sn;
|
||||
short Ks_offset = sm * LDK_tgp + sn;
|
||||
short Vs_offset = sm * LDV_tgp + sn;
|
||||
const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
|
||||
const short Ks_offset = sm * LDK_tgp + sn;
|
||||
const short Vs_offset = sm * LDV_tgp + sn;
|
||||
|
||||
constexpr int Qs_tile_stride = kFragSize;
|
||||
constexpr int Ks_tile_stride = kFragSize * LDK_tgp;
|
||||
constexpr short Qs_tile_stride = kFragSize;
|
||||
constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load Q blocks apply scale
|
||||
loader_q.load_unsafe();
|
||||
loader_q.apply_inplace_op(ts);
|
||||
|
||||
constexpr int kRowsPT = decltype(Stile)::kRowsPerThread;
|
||||
// Init row reduction variables
|
||||
constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;
|
||||
|
||||
AccumType max_score[kRowsPT];
|
||||
AccumType sum_score[kRowsPT] = {0};
|
||||
|
||||
// Init to -Inf
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
max_score[i] = Limits<AccumType>::min;
|
||||
}
|
||||
|
||||
// Loop over KV seq length
|
||||
for (int kb = 0; kb < params->NK; kb++) {
|
||||
// Load Q and K blocks and apply scale
|
||||
// Load K block and apply scale
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_k.load_unsafe();
|
||||
|
||||
@ -246,15 +250,15 @@ template <
|
||||
|
||||
// Do softmax
|
||||
|
||||
// Row max
|
||||
// Temp variables
|
||||
AccumType new_max[kRowsPT];
|
||||
AccumType factor[kRowsPT];
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
new_max[i] = max_score[i];
|
||||
}
|
||||
|
||||
// Row max
|
||||
Stile.template row_reduce<MaxOp>(new_max);
|
||||
|
||||
// exp(Si - rowmax(Si))
|
||||
@ -265,6 +269,8 @@ template <
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
factor[i] = fast::exp(max_score[i] - new_max[i]);
|
||||
}
|
||||
|
||||
// Save max for next iteration
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
max_score[i] = new_max[i];
|
||||
@ -283,20 +289,21 @@ template <
|
||||
// Update O
|
||||
Otile.template row_bin_op<MulOp>(factor);
|
||||
|
||||
// Do O = S @ V
|
||||
// Load V into registers
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Do O = S @ V
|
||||
tile_matmad(Otile, Stile, Vtile, Otile);
|
||||
|
||||
// Prepare for next iteration
|
||||
// loader_q.next();
|
||||
loader_k.next();
|
||||
loader_v.next();
|
||||
}
|
||||
|
||||
// Normalize output
|
||||
Otile.template row_bin_op<DivOp>(sum_score);
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
|
@ -600,7 +600,7 @@ array scaled_dot_product_attention(
|
||||
* * dtype is not fp32 or fp16
|
||||
*/
|
||||
|
||||
int threshold = 1024; // TODO: Fix after dev
|
||||
int threshold = 32; // TODO: Fix after dev
|
||||
if (memory_efficient_threshold.has_value()) {
|
||||
threshold = std::max(1, memory_efficient_threshold.value());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user