Support fused masking in Attention (#1924)
* Update API to allow mask='causal' in fast::sdpa
* Add fallback
* Update steel::AttnParams
* Fix typo
* WIP, basic causal
* Update tests
* Update benchmarking
* Update masking loop limits
* Add bool masking and update tests
* Update additive mask
* Update benchmarks
* Update benchmarks
* Update tests
* Update for bfloat error
* Update early exit
* Add random seed to tests
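For orientation, the three mask modes this commit wires up are driven through the mask argument of fast scaled-dot-product attention. A minimal sketch from Python, assuming the mlx.core.fast entry point; shapes and values are illustrative, not taken from the commit:

    import mlx.core as mx

    B, H, L, D = 1, 8, 256, 64
    q = mx.random.normal((B, H, L, D))
    k = mx.random.normal((B, H, L, D))
    v = mx.random.normal((B, H, L, D))
    scale = D**-0.5

    # Fused causal mode: no L x L mask array is ever materialized.
    o_causal = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")

    # Boolean mask: True means attend, False means mask out.
    idx = mx.arange(L)
    bool_mask = idx[:, None] >= idx[None, :]
    o_bool = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=bool_mask)

    # Additive mask: values are added to the pre-softmax scores.
    add_mask = mx.where(bool_mask, 0.0, float("-inf")).astype(q.dtype)
    o_add = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=add_mask)

All three calls compute the same causal attention; the string form additionally lets the kernel skip fully masked key blocks. The hunks below appear to span the steel attention kernel header, its .metal instantiation file, the steel mma.h and params.h headers, and the host-side scaled_dot_product_attention.cpp.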
@@ -1,4 +1,4 @@
-// Copyright © 2024 Apple Inc.
+// Copyright © 2024-25 Apple Inc.

 using namespace mlx::steel;

@@ -9,6 +9,9 @@ using namespace mlx::steel;
 constant bool align_Q [[function_constant(200)]];
 constant bool align_K [[function_constant(201)]];

+constant bool has_mask [[function_constant(300)]];
+constant bool do_causal [[function_constant(301)]];
+
 template <typename T>
 struct TransformScale {
   T scale;
@@ -69,6 +72,7 @@ template <
     int BD,
     int WM,
     int WN,
+    typename MaskType = float,
     typename AccumType = float>
 [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
     const device T* Q [[buffer(0)]],
@@ -76,6 +80,8 @@ template <
     const device T* V [[buffer(2)]],
     device T* O [[buffer(3)]],
     const constant AttnParams* params [[buffer(4)]],
+    const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
+    const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]],
@@ -102,6 +108,11 @@ template <
       tidl.y * params->O_strides[1] + // Head
       tidl.x * BQ * params->O_strides[2]; // Seqeunce

+  if (has_mask) {
+    mask += tidl.z * mask_params->M_strides[0] + // Batch
+        tidl.y * mask_params->M_strides[1]; // Head
+  }
+
   // Prepare threadgroup memory
   constexpr short padQ = 16 / sizeof(T);
   constexpr short padK = 16 / sizeof(T);
@@ -203,7 +214,7 @@ template <

   // Load Q blocks apply scale
   if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
-    loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
+    loader_q.load_safe(short2(BD, params->qL_rem));
   } else {
     loader_q.load_unsafe();
   }
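The qL_rem substitution here (and kL_rem below) is a host-side hoist: the remainder length of the last partial block was previously recomputed in-kernel as qL - NQ_aligned * BQ; it is now computed once on the CPU and carried in AttnParams.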
@@ -221,12 +232,19 @@ template <
     max_score[i] = Limits<AccumType>::min;
   }

+  int kb_lim = params->NK;
+
+  if (do_causal) {
+    int q_max = (tid.x + 1) * BQ + params->qL_off;
+    kb_lim = (q_max + BK - 1) / BK;
+  }
+
   // Loop over KV seq length
-  for (int kb = 0; kb < params->NK; kb++) {
+  for (int kb = 0; kb < kb_lim; kb++) {
     // Load K block and apply scale
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (!align_K && kb == (params->NK_aligned)) {
-      loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
+      loader_k.load_safe(short2(BD, params->kL_rem));
     } else {
       loader_k.load_unsafe();
     }
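This is the early exit mentioned in the commit message: instead of masking key blocks above the causal diagonal, the K/V loop is simply bounded. The largest query row a threadgroup can own is q_max = (tid.x + 1) * BQ + params->qL_off, and a key block starting at column kb * BK >= q_max is entirely masked, so kb_lim = ceil(q_max / BK) blocks suffice. For the first query tile with BQ = 32, BK = 16 and qL_off = 0, that is 2 key blocks regardless of sequence length.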
@@ -250,12 +268,11 @@ template <
       tile_matmad(Stile, Qtile, Ktile, Stile);
     }

-    // Mask out of length sequence
+    // Mask out length sequence
     if (!align_K && kb == (params->NK_aligned)) {
       using stile_t = decltype(Stile);
       using selem_t = typename stile_t::elem_type;
       constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
-      const short lim = params->kL - params->NK_aligned * BK;

       STEEL_PRAGMA_UNROLL
       for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -264,7 +281,7 @@ template <
           short col_pos = sn + (j * stile_t::kFragCols);
           STEEL_PRAGMA_UNROLL
           for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
-            if ((col_pos + jj) >= lim) {
+            if ((col_pos + jj) >= params->kL_rem) {
               Stile.frag_at(i, j)[jj] = neg_inf;
             }
           }
@@ -272,11 +289,78 @@ template <
         }
       }

+    // Mask out if causal
+    if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) {
+      using stile_t = decltype(Stile);
+      using selem_t = typename stile_t::elem_type;
+      constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
+
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < stile_t::kTileRows; i++) {
+        const int row_pos =
+            tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows);
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < stile_t::kTileCols; j++) {
+          const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
+          STEEL_PRAGMA_UNROLL
+          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
+            if (row_pos < (col_pos + jj)) {
+              Stile.frag_at(i, j)[jj] = neg_inf;
+            }
+          }
+        }
+      }
+    }
+
+    // Other masking as needed
+    if (has_mask) {
+      using stile_t = decltype(Stile);
+      using selem_t = typename stile_t::elem_type;
+      constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
+
+      constexpr bool is_bool = is_same_v<MaskType, bool>;
+      using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;
+
+      using MMAFrag_mask_t = BaseMMAFrag<melem_t, kFragSize, kFragSize>;
+      using frag_t = typename MMAFrag_mask_t::frag_type;
+
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < stile_t::kTileRows; i++) {
+        const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows);
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < stile_t::kTileCols; j++) {
+          const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
+
+          frag_t mfrag;
+
+          MMAFrag_mask_t::load_safe(
+              mfrag,
+              mask,
+              int(mask_params->M_strides[2]),
+              Int<1>{},
+              params->qL,
+              params->kL,
+              row_pos,
+              col_pos);
+
+          STEEL_PRAGMA_UNROLL
+          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) {
+            if constexpr (is_bool) {
+              Stile.frag_at(i, j)[jj] =
+                  mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
+            } else {
+              Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
+            }
+          }
+        }
+      }
+    }
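Two details in the explicit-mask branch are easy to miss. First, the mask fragment is fetched with load_safe against the qL x kL bounds, so out-of-range mask elements read as zero (see the mma.h fix below). Second, additive mask values are scaled by 1.44269504089, which is log2(e); presumably the streaming softmax accumulates in base-2 exponentials (exp2 is cheap on GPU), so a bias m must enter the exponent as log2(e) * m, since 2^(log2(e) * m) = e^m. A quick check of that identity:

    import math

    LOG2E = 1.44269504089  # the constant hard-coded in the kernel; log2(e)
    for m in (-4.0, -0.5, 0.0, 1.25):
        assert abs(2.0 ** (LOG2E * m) - math.exp(m)) < 1e-9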
     threadgroup_barrier(mem_flags::mem_threadgroup);

     // Load V blocks
     if (!align_K && kb == (params->NK_aligned)) {
-      loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
+      loader_v.load_safe(short2(BD, params->kL_rem));
     } else {
       loader_v.load_unsafe();
     }
@@ -367,8 +451,7 @@ template <
   O += (tm + sm) * params->O_strides[2] + sn;

   if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
-    auto dst_tile_dims =
-        short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
+    auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));

     if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
       return;
|
@@ -1,4 +1,4 @@
-// Copyright © 2024 Apple Inc.
+// Copyright © 2024-25 Apple Inc.

 // clang-format off
 #include "mlx/backend/metal/kernels/utils.h"
@@ -6,26 +6,23 @@
 #include "mlx/backend/metal/kernels/steel/attn/attn.h"
 #include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"

-#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
-  template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
-  [[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
-      const device dtype* Q [[buffer(0)]], \
-      const device dtype* K [[buffer(1)]], \
-      const device dtype* V [[buffer(2)]], \
-      device dtype* O [[buffer(3)]], \
-      const constant AttnParams* params [[buffer(4)]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
-      uint3 tid [[threadgroup_position_in_grid]], \
-      uint3 lid [[thread_position_in_threadgroup]]);
+#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \
+  instantiate_kernel( \
+      "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \
+      "_wm" #wm "_wn" #wn "_mask" #mname, \
+      attention, dtype, bq, bk, bd, wm, wn, mtype, float)

-#define instantiate_attn_shapes_helper(iname, itype) \
-  instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \
-  instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
-  instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
+#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
+  instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \
+  instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \
+  instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype)

-instantiate_attn_shapes_helper(float16, half);
-instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
+#define instantiate_attn_mask_helper(iname, itype) \
+  instantiate_attn_shapes_helper(iname, itype, iname, itype) \
+  instantiate_attn_shapes_helper(iname, itype, bool_, bool)

-instantiate_attn_shapes_helper(float32, float);
+instantiate_attn_mask_helper(float16, half);
+instantiate_attn_mask_helper(bfloat16, bfloat16_t);
+
+instantiate_attn_mask_helper(float32, float);
 // clang-format on
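Net effect of the macro rewrite: each dtype/tile-shape combination is now instantiated twice, once with a mask of the matching float type (for additive masks) and once with a bool mask, and the mask type is baked into the kernel name, giving variants such as ..._maskfloat16 and ..._maskbool_ for every shape.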
@@ -111,7 +111,7 @@ struct BaseMMAFrag<T, 8, 8> {
       for (short j = 0; j < kElemCols; j++) {
         if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
           dst[i * kElemCols + j] =
-              static_cast<T>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
+              static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
         } else {
           dst[i * kElemCols + j] = T(0);
         }
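This one-character mma.h change is a bug fix rather than new functionality: load_safe indexed the source column with off_x instead of off_y, which is harmless when the row and column offsets coincide but would corrupt the fragment-wise mask loads above, presumably the first caller to use distinct offsets.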
@@ -26,11 +26,19 @@ struct AttnParams {
   int NQ_aligned; ///< Number of full query blocks
   int NK_aligned; ///< Number of full key/value blocks

+  int qL_rem; ///< Remainder in last query block
+  int kL_rem; ///< Remainder in last key/value block
+  int qL_off; ///< Offset in query sequence start
+
   int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
   int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
   int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
   int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
 };

+struct AttnMaskParams {
+  int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1)
+};
+
 } // namespace steel
 } // namespace mlx
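qL_off encodes the causal-diagonal shift for the kL > qL case (for example, queries appended after a prefilled KV cache): the host sets it to kL - qL, and the kernel adds it to every query row index so that the row_pos < col_pos comparison is made in absolute key coordinates.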
@@ -21,7 +21,9 @@ void sdpa_full_self_attention_metal(
     const array& k,
     const array& v,
     const float scale,
-    array& o) {
+    array& o,
+    bool do_causal_ = false,
+    const std::optional<array>& mask = std::nullopt) {
   using namespace mlx::steel;

   int wm = 4;
@@ -41,11 +43,14 @@ void sdpa_full_self_attention_metal(

   const bool align_Q = (qL % bq) == 0;
   const bool align_K = (kL % bk) == 0;
+  const bool has_mask = !!mask;
+  const bool do_causal = do_causal_;

   metal::MTLFCList func_consts = {
       {&align_Q, MTL::DataType::DataTypeBool, 200},
       {&align_K, MTL::DataType::DataTypeBool, 201},
-  };
+      {&has_mask, MTL::DataType::DataTypeBool, 300},
+      {&do_causal, MTL::DataType::DataTypeBool, 301}};

   std::ostringstream kname;
   // clang-format off
@@ -54,13 +59,17 @@ void sdpa_full_self_attention_metal(
         << "_bq" << bq
         << "_bk" << bk
         << "_bd" << bd
-        << "_wm" << wm << "_wn" << wn; // clang-format on
+        << "_wm" << wm
+        << "_wn" << wn
+        << "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on

   std::string base_name = kname.str();

   // clang-format off
   kname << "_align_Q_" << (align_Q ? 't' : 'n')
-        << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
+        << "_align_K_" << (align_K ? 't' : 'n')
+        << "_has_mask_" << (has_mask ? 't' : 'n')
+        << "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on

   std::string hash_name = kname.str();

@@ -91,6 +100,10 @@ void sdpa_full_self_attention_metal(
       /* int NQ_aligned = */ NQ_aligned,
       /* int NK_aligned = */ NK_aligned,

+      /* int qL_rem = */ (qL - NQ_aligned * bq),
+      /* int kL_rem = */ (kL - NK_aligned * bk),
+      /* int qL_off = */ (kL - qL),
+
       /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
       /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
       /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
@@ -102,6 +115,15 @@ void sdpa_full_self_attention_metal(
   compute_encoder.set_output_array(o, 3);
   compute_encoder.set_bytes(params, 4);

+  if (mask) {
+    auto m = *mask;
+    AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
+        m.strides(0), m.strides(1), m.strides(2)}};
+
+    compute_encoder.set_bytes(mask_params, 5);
+    compute_encoder.set_input_array(m, 6);
+  }
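Buffers 5 and 6 are bound only when a mask is present. That is safe because the kernel declares both arguments behind function_constant(has_mask): Metal specializes the pipeline per function-constant combination (hence has_mask and do_causal also being folded into hash_name), and the unmasked variant simply has no such bindings.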
   MTL::Size grid_dims = MTL::Size(NQ, H, B);
   MTL::Size group_dims = MTL::Size(32, wm, wn);

@@ -346,7 +368,7 @@ void ScaledDotProductAttention::eval_gpu(

   // Checks that the headdim dimension has stride 1.
   auto is_matrix_contiguous = [](const array& arr) {
-    return arr.strides(3) == 1;
+    return arr.strides(-1) == 1;
   };

   // We are in vector mode ie single query
@@ -415,7 +437,11 @@ void ScaledDotProductAttention::eval_gpu(
         {str_oB, str_oH, str_oL, str_oD},
         flags);

-    sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
+    auto mask = inputs.size() > 3
+        ? std::optional<array>{copy_unless(is_matrix_contiguous, inputs[3])}
+        : std::nullopt;
+
+    sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask);
   }

   d.add_temporaries(std::move(copies), s.index);