[WIP] Add support for unaligned seq lengths - still looks messy

This commit is contained in:
Jagrit Digani 2024-11-20 15:16:05 -08:00
parent c1dc852995
commit 83c4f6bde6
4 changed files with 104 additions and 34 deletions

View File

@ -6,18 +6,8 @@ using namespace mlx::steel;
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
constant bool has_batch [[function_constant(10)]];
constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
constant bool do_gather [[function_constant(300)]];
constant bool gather_bias = do_gather && use_out_source;
constant bool align_Q [[function_constant(200)]];
constant bool align_K [[function_constant(201)]];
template <typename T>
struct TransformScale {
@ -204,7 +194,11 @@ template <
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load Q blocks apply scale
loader_q.load_unsafe();
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
} else {
loader_q.load_unsafe();
}
loader_q.apply_inplace_op(ts);
// Init row reduction variables
@ -223,7 +217,11 @@ template <
for (int kb = 0; kb < params->NK; kb++) {
// Load K block and apply scale
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_k.load_unsafe();
if (!align_K && kb == (params->NK_aligned)) {
loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
} else {
loader_k.load_unsafe();
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@ -243,10 +241,36 @@ template <
tile_matmad(Stile, Qtile, Ktile, Stile);
}
// Mask out of length sequence
if (!align_K && kb == (params->NK_aligned)) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
const short lim = params->kL - params->NK_aligned * BK;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < stile_t::kTileCols; j++) {
short col_pos = sn + (j * stile_t::kFragCols);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
if ((col_pos + jj) >= lim) {
Stile.frag_at(i, j)[jj] = neg_inf;
}
}
}
}
}
simdgroup_barrier(mem_flags::mem_none);
// Load V blocks
loader_v.load_unsafe();
if (!align_K && kb == (params->NK_aligned)) {
loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
} else {
loader_v.load_unsafe();
}
// Do softmax
@ -309,5 +333,16 @@ template <
// Store results
O += (tm + sm) * params->O_strides[2] + sn;
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
auto dst_tile_dims =
short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
} else {
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
}
}

View File

@ -21,11 +21,9 @@
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_attn_shapes_helper(iname, itype) \
instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
instantiate_attn(iname, itype, 32, 16, 80, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 64, 4, 1) \
instantiate_attn(iname, itype, 32, 16, 64, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 128, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
instantiate_attn_shapes_helper(float16, half);
// instantiate_attn_shapes_helper(bfloat16, bfloat16_t);

View File

@ -12,15 +12,20 @@ namespace steel {
struct AttnParams {
int B; ///< Batch Size
int H; ///< Heads
int L; ///< Sequence Length
int D; ///< Head Dim
int qL; ///< Query Sequence Length
int kL; ///< Key Sequence Length
int gqa_factor; ///< Group Query factor
float scale; ///< Attention scale
int NQ; ///< Number of query blocks
int NK; ///< Number of key/value blocks
int NQ_aligned; ///< Number of full query blocks
int NK_aligned; ///< Number of full key/value blocks
size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
size_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
size_t V_strides[3]; ///< Value strides (B, H, L, D = 1)

View File

@ -33,34 +33,66 @@ void sdpa_full_self_attention_metal(
int bk = 32;
int bd = q.shape(-1);
std::ostringstream kname;
kname << "steel_attention_" << type_to_name(q) << "_bq" << bq << "_bk" << bk
<< "_bd" << bd << "_wm" << wm << "_wn" << wn;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
int B = q.shape(0);
int H = q.shape(1);
int L = q.shape(2);
int D = q.shape(3);
int gqa_factor = q.shape(1) / k.shape(1);
int NQ = (L + bq - 1) / bq;
int NK = (L + bk - 1) / bk;
int qL = q.shape(2);
int kL = k.shape(2);
const bool align_Q = (qL % bq) == 0;
const bool align_K = (kL % bk) == 0;
metal::MTLFCList func_consts = {
{&align_Q, MTL::DataType::DataTypeBool, 200},
{&align_K, MTL::DataType::DataTypeBool, 201},
};
std::ostringstream kname;
// clang-format off
kname << "steel_attention_"
<< type_to_name(q)
<< "_bq" << bq
<< "_bk" << bk
<< "_bd" << bd
<< "_wm" << wm << "_wn" << wn; // clang-format on
std::string base_name = kname.str();
// clang-format off
kname << "_align_Q_" << (align_Q ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
const int NQ = (qL + bq - 1) / bq;
const int NK = (kL + bk - 1) / bk;
const int NQ_aligned = qL / bq;
const int NK_aligned = kL / bk;
AttnParams params{
/* int B = */ B,
/* int H = */ H,
/* int L = */ L,
/* int D = */ D,
/* int qL = */ qL,
/* int kL = */ kL,
/* int gqa_factor = */ gqa_factor,
/* float scale = */ scale,
/* int NQ = */ NQ,
/* int NK = */ NK,
/* int NQ_aligned = */ NQ_aligned,
/* int NK_aligned = */ NK_aligned,
/* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},