mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-12 19:11:19 +08:00
[WIP] Add support for unaligned seq lengths - still looks messy
This commit is contained in:
parent
c1dc852995
commit
83c4f6bde6
@ -6,18 +6,8 @@ using namespace mlx::steel;
|
|||||||
// GEMM kernels
|
// GEMM kernels
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
constant bool has_batch [[function_constant(10)]];
|
constant bool align_Q [[function_constant(200)]];
|
||||||
|
constant bool align_K [[function_constant(201)]];
|
||||||
constant bool use_out_source [[function_constant(100)]];
|
|
||||||
constant bool do_axpby [[function_constant(110)]];
|
|
||||||
|
|
||||||
constant bool align_M [[function_constant(200)]];
|
|
||||||
constant bool align_N [[function_constant(201)]];
|
|
||||||
constant bool align_K [[function_constant(202)]];
|
|
||||||
|
|
||||||
constant bool do_gather [[function_constant(300)]];
|
|
||||||
|
|
||||||
constant bool gather_bias = do_gather && use_out_source;
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct TransformScale {
|
struct TransformScale {
|
||||||
@ -204,7 +194,11 @@ template <
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// Load Q blocks apply scale
|
// Load Q blocks apply scale
|
||||||
loader_q.load_unsafe();
|
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
|
||||||
|
loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
|
||||||
|
} else {
|
||||||
|
loader_q.load_unsafe();
|
||||||
|
}
|
||||||
loader_q.apply_inplace_op(ts);
|
loader_q.apply_inplace_op(ts);
|
||||||
|
|
||||||
// Init row reduction variables
|
// Init row reduction variables
|
||||||
@ -223,7 +217,11 @@ template <
|
|||||||
for (int kb = 0; kb < params->NK; kb++) {
|
for (int kb = 0; kb < params->NK; kb++) {
|
||||||
// Load K block and apply scale
|
// Load K block and apply scale
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
loader_k.load_unsafe();
|
if (!align_K && kb == (params->NK_aligned)) {
|
||||||
|
loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
|
||||||
|
} else {
|
||||||
|
loader_k.load_unsafe();
|
||||||
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
@ -243,10 +241,36 @@ template <
|
|||||||
tile_matmad(Stile, Qtile, Ktile, Stile);
|
tile_matmad(Stile, Qtile, Ktile, Stile);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mask out of length sequence
|
||||||
|
if (!align_K && kb == (params->NK_aligned)) {
|
||||||
|
using stile_t = decltype(Stile);
|
||||||
|
using selem_t = typename stile_t::elem_type;
|
||||||
|
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
||||||
|
const short lim = params->kL - params->NK_aligned * BK;
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < stile_t::kTileRows; i++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < stile_t::kTileCols; j++) {
|
||||||
|
short col_pos = sn + (j * stile_t::kFragCols);
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
|
||||||
|
if ((col_pos + jj) >= lim) {
|
||||||
|
Stile.frag_at(i, j)[jj] = neg_inf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
// Load V blocks
|
// Load V blocks
|
||||||
loader_v.load_unsafe();
|
if (!align_K && kb == (params->NK_aligned)) {
|
||||||
|
loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
|
||||||
|
} else {
|
||||||
|
loader_v.load_unsafe();
|
||||||
|
}
|
||||||
|
|
||||||
// Do softmax
|
// Do softmax
|
||||||
|
|
||||||
@ -309,5 +333,16 @@ template <
|
|||||||
|
|
||||||
// Store results
|
// Store results
|
||||||
O += (tm + sm) * params->O_strides[2] + sn;
|
O += (tm + sm) * params->O_strides[2] + sn;
|
||||||
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
|
|
||||||
|
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
|
||||||
|
auto dst_tile_dims =
|
||||||
|
short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
|
||||||
|
|
||||||
|
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
||||||
|
return;
|
||||||
|
|
||||||
|
Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
|
||||||
|
} else {
|
||||||
|
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -21,11 +21,9 @@
|
|||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
#define instantiate_attn_shapes_helper(iname, itype) \
|
#define instantiate_attn_shapes_helper(iname, itype) \
|
||||||
instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
|
instantiate_attn(iname, itype, 32, 32, 128, 4, 1) \
|
||||||
instantiate_attn(iname, itype, 32, 16, 80, 4, 1) \
|
instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
|
||||||
instantiate_attn(iname, itype, 32, 32, 64, 4, 1) \
|
instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
|
||||||
instantiate_attn(iname, itype, 32, 16, 64, 4, 1) \
|
|
||||||
|
|
||||||
|
|
||||||
instantiate_attn_shapes_helper(float16, half);
|
instantiate_attn_shapes_helper(float16, half);
|
||||||
// instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
|
// instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
|
||||||
|
@ -12,15 +12,20 @@ namespace steel {
|
|||||||
struct AttnParams {
|
struct AttnParams {
|
||||||
int B; ///< Batch Size
|
int B; ///< Batch Size
|
||||||
int H; ///< Heads
|
int H; ///< Heads
|
||||||
int L; ///< Sequence Length
|
|
||||||
int D; ///< Head Dim
|
int D; ///< Head Dim
|
||||||
|
|
||||||
|
int qL; ///< Query Sequence Length
|
||||||
|
int kL; ///< Key Sequence Length
|
||||||
|
|
||||||
int gqa_factor; ///< Group Query factor
|
int gqa_factor; ///< Group Query factor
|
||||||
float scale; ///< Attention scale
|
float scale; ///< Attention scale
|
||||||
|
|
||||||
int NQ; ///< Number of query blocks
|
int NQ; ///< Number of query blocks
|
||||||
int NK; ///< Number of key/value blocks
|
int NK; ///< Number of key/value blocks
|
||||||
|
|
||||||
|
int NQ_aligned; ///< Number of full query blocks
|
||||||
|
int NK_aligned; ///< Number of full key/value blocks
|
||||||
|
|
||||||
size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
|
size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
|
||||||
size_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
|
size_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
|
||||||
size_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
|
size_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
|
||||||
|
@ -33,34 +33,66 @@ void sdpa_full_self_attention_metal(
|
|||||||
int bk = 32;
|
int bk = 32;
|
||||||
int bd = q.shape(-1);
|
int bd = q.shape(-1);
|
||||||
|
|
||||||
std::ostringstream kname;
|
|
||||||
kname << "steel_attention_" << type_to_name(q) << "_bq" << bq << "_bk" << bk
|
|
||||||
<< "_bd" << bd << "_wm" << wm << "_wn" << wn;
|
|
||||||
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
||||||
auto kernel = d.get_kernel(kname.str());
|
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
|
||||||
|
|
||||||
int B = q.shape(0);
|
int B = q.shape(0);
|
||||||
int H = q.shape(1);
|
int H = q.shape(1);
|
||||||
int L = q.shape(2);
|
|
||||||
int D = q.shape(3);
|
int D = q.shape(3);
|
||||||
int gqa_factor = q.shape(1) / k.shape(1);
|
int gqa_factor = q.shape(1) / k.shape(1);
|
||||||
|
|
||||||
int NQ = (L + bq - 1) / bq;
|
int qL = q.shape(2);
|
||||||
int NK = (L + bk - 1) / bk;
|
int kL = k.shape(2);
|
||||||
|
|
||||||
|
const bool align_Q = (qL % bq) == 0;
|
||||||
|
const bool align_K = (kL % bk) == 0;
|
||||||
|
|
||||||
|
metal::MTLFCList func_consts = {
|
||||||
|
{&align_Q, MTL::DataType::DataTypeBool, 200},
|
||||||
|
{&align_K, MTL::DataType::DataTypeBool, 201},
|
||||||
|
};
|
||||||
|
|
||||||
|
std::ostringstream kname;
|
||||||
|
// clang-format off
|
||||||
|
kname << "steel_attention_"
|
||||||
|
<< type_to_name(q)
|
||||||
|
<< "_bq" << bq
|
||||||
|
<< "_bk" << bk
|
||||||
|
<< "_bd" << bd
|
||||||
|
<< "_wm" << wm << "_wn" << wn; // clang-format on
|
||||||
|
|
||||||
|
std::string base_name = kname.str();
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
kname << "_align_Q_" << (align_Q ? 't' : 'n')
|
||||||
|
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
|
||||||
|
|
||||||
|
std::string hash_name = kname.str();
|
||||||
|
|
||||||
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
|
const int NQ = (qL + bq - 1) / bq;
|
||||||
|
const int NK = (kL + bk - 1) / bk;
|
||||||
|
|
||||||
|
const int NQ_aligned = qL / bq;
|
||||||
|
const int NK_aligned = kL / bk;
|
||||||
|
|
||||||
AttnParams params{
|
AttnParams params{
|
||||||
/* int B = */ B,
|
/* int B = */ B,
|
||||||
/* int H = */ H,
|
/* int H = */ H,
|
||||||
/* int L = */ L,
|
|
||||||
/* int D = */ D,
|
/* int D = */ D,
|
||||||
|
|
||||||
|
/* int qL = */ qL,
|
||||||
|
/* int kL = */ kL,
|
||||||
|
|
||||||
/* int gqa_factor = */ gqa_factor,
|
/* int gqa_factor = */ gqa_factor,
|
||||||
/* float scale = */ scale,
|
/* float scale = */ scale,
|
||||||
|
|
||||||
/* int NQ = */ NQ,
|
/* int NQ = */ NQ,
|
||||||
/* int NK = */ NK,
|
/* int NK = */ NK,
|
||||||
|
|
||||||
|
/* int NQ_aligned = */ NQ_aligned,
|
||||||
|
/* int NK_aligned = */ NK_aligned,
|
||||||
|
|
||||||
/* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
|
/* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
|
||||||
/* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
|
/* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
|
||||||
/* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
|
/* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
|
||||||
|
Loading…
Reference in New Issue
Block a user