mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-22 09:31:14 +08:00
[WIP] Update dispatch params for testing
This commit is contained in:
parent
2cd1de0e47
commit
c1dc852995
@ -91,6 +91,10 @@ template <
|
|||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
|
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
|
||||||
|
|
||||||
|
// Pacifying compiler
|
||||||
|
(void)lid;
|
||||||
|
|
||||||
|
// Move to correct block
|
||||||
ulong3 tidl{tid.x, tid.y, tid.z};
|
ulong3 tidl{tid.x, tid.y, tid.z};
|
||||||
|
|
||||||
Q += tidl.z * params->Q_strides[0] + // Batch
|
Q += tidl.z * params->Q_strides[0] + // Batch
|
||||||
@ -107,26 +111,20 @@ template <
|
|||||||
tidl.y * params->O_strides[1] + // Head
|
tidl.y * params->O_strides[1] + // Head
|
||||||
tidl.x * BQ * params->O_strides[2]; // Seqeunce
|
tidl.x * BQ * params->O_strides[2]; // Seqeunce
|
||||||
|
|
||||||
constexpr int padQ = 0; // 16 / sizeof(T);
|
// Prepare threadgroup memory
|
||||||
constexpr int padK = 0; // 16 / sizeof(T);
|
constexpr short padQ = 0; // 16 / sizeof(T);
|
||||||
constexpr int padV = 0; // 16 / sizeof(T);
|
constexpr short padK = 0; // 16 / sizeof(T);
|
||||||
|
constexpr short padV = 0; // 16 / sizeof(T);
|
||||||
|
|
||||||
// using QBlockSrcShape = CShape<BQ, BD>;
|
constexpr short LDQ_tgp = BD + padQ;
|
||||||
// using KBlockSrcShape = CShape<BK, BD>;
|
constexpr short LDK_tgp = BK + padK;
|
||||||
// using VBlockSrcShape = CShape<BK, BD>;
|
constexpr short LDV_tgp = BD + padV;
|
||||||
|
|
||||||
constexpr int LDQ_tgp = BD + padQ;
|
|
||||||
constexpr int LDK_tgp = BK + padK;
|
|
||||||
constexpr int LDV_tgp = BD + padV;
|
|
||||||
|
|
||||||
// using QBlockDstStrides = CShape<LDQ_tgp, 1>;
|
|
||||||
// using KBlockDstStrides = CShape<1, LDK_tgp>;
|
|
||||||
// using QBlockDstStrides = CShape<LDV_tgp, 1>;
|
|
||||||
|
|
||||||
threadgroup T Qs[BQ * (BD + padQ)];
|
threadgroup T Qs[BQ * (BD + padQ)];
|
||||||
threadgroup T Ks[(BK + padK) * BD];
|
threadgroup T Ks[(BK + padK) * BD];
|
||||||
threadgroup T Vs[BK * (BD + padV)];
|
threadgroup T Vs[BK * (BD + padV)];
|
||||||
|
|
||||||
|
// Prepare block loaders
|
||||||
using QBlockLoader = BlockLoaderT<
|
using QBlockLoader = BlockLoaderT<
|
||||||
/* typename T = */ T,
|
/* typename T = */ T,
|
||||||
/* short BROWS = */ BQ,
|
/* short BROWS = */ BQ,
|
||||||
@ -136,6 +134,7 @@ template <
|
|||||||
/* short reduction_dim = */ 1,
|
/* short reduction_dim = */ 1,
|
||||||
/* short tgp_size = */ WM * WN * 32>;
|
/* short tgp_size = */ WM * WN * 32>;
|
||||||
|
|
||||||
|
// K is loaded in transposed
|
||||||
using KBlockLoader = BlockLoaderT<
|
using KBlockLoader = BlockLoaderT<
|
||||||
/* typename T = */ T,
|
/* typename T = */ T,
|
||||||
/* short BROWS = */ BK,
|
/* short BROWS = */ BK,
|
||||||
@ -163,8 +162,8 @@ template <
|
|||||||
|
|
||||||
TransformScale<T> ts(static_cast<T>(params->scale));
|
TransformScale<T> ts(static_cast<T>(params->scale));
|
||||||
|
|
||||||
// MMAFrag size
|
// Prepare MMA tiles
|
||||||
constexpr short kFragSize = 8;
|
constexpr short kFragSize = 8; // MMAFrag size
|
||||||
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
|
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
|
||||||
|
|
||||||
constexpr int kNWarps = WM * WN;
|
constexpr int kNWarps = WM * WN;
|
||||||
@ -189,35 +188,40 @@ template <
|
|||||||
|
|
||||||
Otile.clear();
|
Otile.clear();
|
||||||
|
|
||||||
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
// Prepare mma tile offsets
|
||||||
short sm = simd_coord.y;
|
const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
||||||
short sn = simd_coord.x;
|
const short sm = simd_coord.y;
|
||||||
short tm = kFragSize * TQ * simd_group_id;
|
const short sn = simd_coord.x;
|
||||||
|
const short tm = kFragSize * TQ * simd_group_id;
|
||||||
|
|
||||||
short Qs_offset = (tm + sm) * LDQ_tgp + sn;
|
const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
|
||||||
short Ks_offset = sm * LDK_tgp + sn;
|
const short Ks_offset = sm * LDK_tgp + sn;
|
||||||
short Vs_offset = sm * LDV_tgp + sn;
|
const short Vs_offset = sm * LDV_tgp + sn;
|
||||||
|
|
||||||
constexpr int Qs_tile_stride = kFragSize;
|
constexpr short Qs_tile_stride = kFragSize;
|
||||||
constexpr int Ks_tile_stride = kFragSize * LDK_tgp;
|
constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Load Q blocks apply scale
|
||||||
loader_q.load_unsafe();
|
loader_q.load_unsafe();
|
||||||
loader_q.apply_inplace_op(ts);
|
loader_q.apply_inplace_op(ts);
|
||||||
|
|
||||||
constexpr int kRowsPT = decltype(Stile)::kRowsPerThread;
|
// Init row reduction variables
|
||||||
|
constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;
|
||||||
|
|
||||||
AccumType max_score[kRowsPT];
|
AccumType max_score[kRowsPT];
|
||||||
AccumType sum_score[kRowsPT] = {0};
|
AccumType sum_score[kRowsPT] = {0};
|
||||||
|
|
||||||
|
// Init to -Inf
|
||||||
STEEL_PRAGMA_UNROLL
|
STEEL_PRAGMA_UNROLL
|
||||||
for (short i = 0; i < kRowsPT; ++i) {
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
max_score[i] = Limits<AccumType>::min;
|
max_score[i] = Limits<AccumType>::min;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Loop over KV seq length
|
||||||
for (int kb = 0; kb < params->NK; kb++) {
|
for (int kb = 0; kb < params->NK; kb++) {
|
||||||
// Load Q and K blocks and apply scale
|
// Load K block and apply scale
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
loader_k.load_unsafe();
|
loader_k.load_unsafe();
|
||||||
|
|
||||||
@ -246,15 +250,15 @@ template <
|
|||||||
|
|
||||||
// Do softmax
|
// Do softmax
|
||||||
|
|
||||||
// Row max
|
// Temp variables
|
||||||
AccumType new_max[kRowsPT];
|
AccumType new_max[kRowsPT];
|
||||||
AccumType factor[kRowsPT];
|
AccumType factor[kRowsPT];
|
||||||
|
|
||||||
STEEL_PRAGMA_UNROLL
|
STEEL_PRAGMA_UNROLL
|
||||||
for (short i = 0; i < kRowsPT; ++i) {
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
new_max[i] = max_score[i];
|
new_max[i] = max_score[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Row max
|
||||||
Stile.template row_reduce<MaxOp>(new_max);
|
Stile.template row_reduce<MaxOp>(new_max);
|
||||||
|
|
||||||
// exp(Si - rowmax(Si))
|
// exp(Si - rowmax(Si))
|
||||||
@ -265,6 +269,8 @@ template <
|
|||||||
for (short i = 0; i < kRowsPT; ++i) {
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
factor[i] = fast::exp(max_score[i] - new_max[i]);
|
factor[i] = fast::exp(max_score[i] - new_max[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Save max for next iteration
|
||||||
STEEL_PRAGMA_UNROLL
|
STEEL_PRAGMA_UNROLL
|
||||||
for (short i = 0; i < kRowsPT; ++i) {
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
max_score[i] = new_max[i];
|
max_score[i] = new_max[i];
|
||||||
@ -283,20 +289,21 @@ template <
|
|||||||
// Update O
|
// Update O
|
||||||
Otile.template row_bin_op<MulOp>(factor);
|
Otile.template row_bin_op<MulOp>(factor);
|
||||||
|
|
||||||
// Do O = S @ V
|
// Load V into registers
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
|
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Do O = S @ V
|
||||||
tile_matmad(Otile, Stile, Vtile, Otile);
|
tile_matmad(Otile, Stile, Vtile, Otile);
|
||||||
|
|
||||||
// Prepare for next iteration
|
// Prepare for next iteration
|
||||||
// loader_q.next();
|
|
||||||
loader_k.next();
|
loader_k.next();
|
||||||
loader_v.next();
|
loader_v.next();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Normalize output
|
||||||
Otile.template row_bin_op<DivOp>(sum_score);
|
Otile.template row_bin_op<DivOp>(sum_score);
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
@ -600,7 +600,7 @@ array scaled_dot_product_attention(
|
|||||||
* * dtype is not fp32 or fp16
|
* * dtype is not fp32 or fp16
|
||||||
*/
|
*/
|
||||||
|
|
||||||
int threshold = 1024; // TODO: Fix after dev
|
int threshold = 32; // TODO: Fix after dev
|
||||||
if (memory_efficient_threshold.has_value()) {
|
if (memory_efficient_threshold.has_value()) {
|
||||||
threshold = std::max(1, memory_efficient_threshold.value());
|
threshold = std::max(1, memory_efficient_threshold.value());
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user