mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-12 11:01:15 +08:00
[WIP] Add support for unaligned seq lengths - still looks messy
This commit is contained in:
parent
c1dc852995
commit
83c4f6bde6
@ -6,18 +6,8 @@ using namespace mlx::steel;
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
constant bool has_batch [[function_constant(10)]];
|
||||
|
||||
constant bool use_out_source [[function_constant(100)]];
|
||||
constant bool do_axpby [[function_constant(110)]];
|
||||
|
||||
constant bool align_M [[function_constant(200)]];
|
||||
constant bool align_N [[function_constant(201)]];
|
||||
constant bool align_K [[function_constant(202)]];
|
||||
|
||||
constant bool do_gather [[function_constant(300)]];
|
||||
|
||||
constant bool gather_bias = do_gather && use_out_source;
|
||||
constant bool align_Q [[function_constant(200)]];
|
||||
constant bool align_K [[function_constant(201)]];
|
||||
|
||||
template <typename T>
|
||||
struct TransformScale {
|
||||
@ -204,7 +194,11 @@ template <
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load Q blocks and apply scale
|
||||
loader_q.load_unsafe();
|
||||
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
|
||||
loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
|
||||
} else {
|
||||
loader_q.load_unsafe();
|
||||
}
|
||||
loader_q.apply_inplace_op(ts);
|
||||
|
||||
// Init row reduction variables
|
||||
@ -223,7 +217,11 @@ template <
|
||||
for (int kb = 0; kb < params->NK; kb++) {
|
||||
// Load K block and apply scale
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_k.load_unsafe();
|
||||
if (!align_K && kb == (params->NK_aligned)) {
|
||||
loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
|
||||
} else {
|
||||
loader_k.load_unsafe();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
@ -243,10 +241,36 @@ template <
|
||||
tile_matmad(Stile, Qtile, Ktile, Stile);
|
||||
}
|
||||
|
||||
// Mask out scores for key positions past the end of the key sequence (last, partial K block only)
|
||||
if (!align_K && kb == (params->NK_aligned)) {
|
||||
using stile_t = decltype(Stile);
|
||||
using selem_t = typename stile_t::elem_type;
|
||||
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
||||
const short lim = params->kL - params->NK_aligned * BK;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < stile_t::kTileRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < stile_t::kTileCols; j++) {
|
||||
short col_pos = sn + (j * stile_t::kFragCols);
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
|
||||
if ((col_pos + jj) >= lim) {
|
||||
Stile.frag_at(i, j)[jj] = neg_inf;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load V blocks
|
||||
loader_v.load_unsafe();
|
||||
if (!align_K && kb == (params->NK_aligned)) {
|
||||
loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
|
||||
} else {
|
||||
loader_v.load_unsafe();
|
||||
}
|
||||
|
||||
// Do softmax
|
||||
|
||||
@ -309,5 +333,16 @@ template <
|
||||
|
||||
// Store results
|
||||
O += (tm + sm) * params->O_strides[2] + sn;
|
||||
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
|
||||
|
||||
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
|
||||
auto dst_tile_dims =
|
||||
short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
|
||||
|
||||
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
||||
return;
|
||||
|
||||
Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
|
||||
} else {
|
||||
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
|
||||
}
|
||||
}
|
||||
|
@ -21,11 +21,9 @@
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_attn_shapes_helper(iname, itype) \
|
||||
instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
|
||||
instantiate_attn(iname, itype, 32, 16, 80, 4, 1) \
|
||||
instantiate_attn(iname, itype, 32, 32, 64, 4, 1) \
|
||||
instantiate_attn(iname, itype, 32, 16, 64, 4, 1) \
|
||||
|
||||
instantiate_attn(iname, itype, 32, 32, 128, 4, 1) \
|
||||
instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
|
||||
instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
|
||||
|
||||
instantiate_attn_shapes_helper(float16, half);
|
||||
// instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
|
||||
|
@ -12,15 +12,20 @@ namespace steel {
|
||||
struct AttnParams {
|
||||
int B; ///< Batch Size
|
||||
int H; ///< Heads
|
||||
int L; ///< Sequence Length
|
||||
int D; ///< Head Dim
|
||||
|
||||
int qL; ///< Query Sequence Length
|
||||
int kL; ///< Key Sequence Length
|
||||
|
||||
int gqa_factor; ///< Group Query factor
|
||||
float scale; ///< Attention scale
|
||||
|
||||
int NQ; ///< Number of query blocks
|
||||
int NK; ///< Number of key/value blocks
|
||||
|
||||
int NQ_aligned; ///< Number of full query blocks
|
||||
int NK_aligned; ///< Number of full key/value blocks
|
||||
|
||||
size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
|
||||
size_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
|
||||
size_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
|
||||
|
@ -33,34 +33,66 @@ void sdpa_full_self_attention_metal(
|
||||
int bk = 32;
|
||||
int bd = q.shape(-1);
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "steel_attention_" << type_to_name(q) << "_bq" << bq << "_bk" << bk
|
||||
<< "_bd" << bd << "_wm" << wm << "_wn" << wn;
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
int B = q.shape(0);
|
||||
int H = q.shape(1);
|
||||
int L = q.shape(2);
|
||||
int D = q.shape(3);
|
||||
int gqa_factor = q.shape(1) / k.shape(1);
|
||||
|
||||
int NQ = (L + bq - 1) / bq;
|
||||
int NK = (L + bk - 1) / bk;
|
||||
int qL = q.shape(2);
|
||||
int kL = k.shape(2);
|
||||
|
||||
const bool align_Q = (qL % bq) == 0;
|
||||
const bool align_K = (kL % bk) == 0;
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&align_Q, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_K, MTL::DataType::DataTypeBool, 201},
|
||||
};
|
||||
|
||||
std::ostringstream kname;
|
||||
// clang-format off
|
||||
kname << "steel_attention_"
|
||||
<< type_to_name(q)
|
||||
<< "_bq" << bq
|
||||
<< "_bk" << bk
|
||||
<< "_bd" << bd
|
||||
<< "_wm" << wm << "_wn" << wn; // clang-format on
|
||||
|
||||
std::string base_name = kname.str();
|
||||
|
||||
// clang-format off
|
||||
kname << "_align_Q_" << (align_Q ? 't' : 'n')
|
||||
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
const int NQ = (qL + bq - 1) / bq;
|
||||
const int NK = (kL + bk - 1) / bk;
|
||||
|
||||
const int NQ_aligned = qL / bq;
|
||||
const int NK_aligned = kL / bk;
|
||||
|
||||
AttnParams params{
|
||||
/* int B = */ B,
|
||||
/* int H = */ H,
|
||||
/* int L = */ L,
|
||||
/* int D = */ D,
|
||||
|
||||
/* int qL = */ qL,
|
||||
/* int kL = */ kL,
|
||||
|
||||
/* int gqa_factor = */ gqa_factor,
|
||||
/* float scale = */ scale,
|
||||
|
||||
/* int NQ = */ NQ,
|
||||
/* int NK = */ NK,
|
||||
|
||||
/* int NQ_aligned = */ NQ_aligned,
|
||||
/* int NK_aligned = */ NK_aligned,
|
||||
|
||||
/* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
|
||||
/* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
|
||||
/* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
|
||||
|
Loading…
Reference in New Issue
Block a user