Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Support fused masking in Attention (#1924)
* Update API to allow mask='causal' in fast::sdpa
* Add fallback
* Update steel::AttnParams
* Fix typo
* WIP, basic causal
* Update tests
* Update benchmarking
* Update masking loop limits
* Add bool masking and update tests
* Update additive mask
* Update benchmarks
* Update benchmarks
* Update tests
* Update for bfloat error
* Update early exit
* Add random seed to tests
@@ -21,7 +21,9 @@ void sdpa_full_self_attention_metal(
     const array& k,
     const array& v,
     const float scale,
-    array& o) {
+    array& o,
+    bool do_causal_ = false,
+    const std::optional<array>& mask = std::nullopt) {
   using namespace mlx::steel;
 
   int wm = 4;
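The helper's new trailing parameters carry the masking mode: do_causal_ requests a fused causal mask computed inside the kernel, while mask optionally supplies a boolean or additive mask array. A minimal sketch of the intended semantics, with illustrative names not taken from this diff:

#include <cmath>

// Reference semantics of the fused masking modes (a sketch, not kernel
// code; masked_score and its arguments are hypothetical).
float masked_score(float score, float mask_val, int i, int j, int qL, int kL,
                   bool causal, bool additive) {
  int qL_off = kL - qL; // right-align sequences: the last query sees all keys
  if (causal && j > i + qL_off) {
    return -INFINITY; // dropped from the softmax
  }
  return additive ? score + mask_val : score;
}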
@@ -41,11 +43,14 @@ void sdpa_full_self_attention_metal(
 
   const bool align_Q = (qL % bq) == 0;
   const bool align_K = (kL % bk) == 0;
+  const bool has_mask = !!mask;
+  const bool do_causal = do_causal_;
 
   metal::MTLFCList func_consts = {
       {&align_Q, MTL::DataType::DataTypeBool, 200},
       {&align_K, MTL::DataType::DataTypeBool, 201},
-  };
+      {&has_mask, MTL::DataType::DataTypeBool, 300},
+      {&do_causal, MTL::DataType::DataTypeBool, 301}};
 
   std::ostringstream kname;
   // clang-format off
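The two added entries extend the Metal function-constant list used to specialize the kernel at pipeline-build time; indices 300 and 301 would match [[function_constant(...)]] slots on the shader side. Roughly how such a list reaches Metal through metal-cpp, a sketch of plumbing that lives in MLX's kernel cache rather than in this diff:

#include <Metal/Metal.hpp>

// Build the function-constant values for one specialization; the caller
// owns (and eventually releases) the returned object.
MTL::FunctionConstantValues* make_consts(
    bool align_Q, bool align_K, bool has_mask, bool do_causal) {
  auto* consts = MTL::FunctionConstantValues::alloc()->init();
  consts->setConstantValue(&align_Q, MTL::DataType::DataTypeBool, 200);
  consts->setConstantValue(&align_K, MTL::DataType::DataTypeBool, 201);
  consts->setConstantValue(&has_mask, MTL::DataType::DataTypeBool, 300);
  consts->setConstantValue(&do_causal, MTL::DataType::DataTypeBool, 301);
  // The Metal compiler folds these booleans when specializing, so the
  // masking branches cost nothing in the unmasked variant.
  return consts;
}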
@@ -54,13 +59,17 @@ void sdpa_full_self_attention_metal(
         << "_bq" << bq
         << "_bk" << bk
         << "_bd" << bd
-        << "_wm" << wm << "_wn" << wn; // clang-format on
+        << "_wm" << wm
+        << "_wn" << wn
+        << "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on
 
   std::string base_name = kname.str();
 
   // clang-format off
   kname << "_align_Q_" << (align_Q ? 't' : 'n')
-        << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
+        << "_align_K_" << (align_K ? 't' : 'n')
+        << "_has_mask_" << (has_mask ? 't' : 'n')
+        << "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on
 
   std::string hash_name = kname.str();
 
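Note the two-stage naming: base_name chooses which kernel source gets built, while hash_name additionally encodes the function constants and keys the cache of specialized pipelines. An illustrative rendering, with a hypothetical prefix and tile sizes not taken from this diff:

// Assuming float16 q/k/v, a bool mask, and bq=32, bk=16, bd=128, wm=4, wn=1:
//   base_name: "..._float16_bq32_bk16_bd128_wm4_wn1_maskbool"
//   hash_name: base_name + "_align_Q_t_align_K_n_has_mask_t_do_causal_n"
// One compiled source per base_name; one specialized pipeline per hash_name.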
@@ -91,6 +100,10 @@ void sdpa_full_self_attention_metal(
       /* int NQ_aligned = */ NQ_aligned,
       /* int NK_aligned = */ NK_aligned,
 
+      /* int qL_rem = */ (qL - NQ_aligned * bq),
+      /* int kL_rem = */ (kL - NK_aligned * bk),
+      /* int qL_off = */ (kL - qL),
+
       /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
       /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
       /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
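These fields let the kernel bounds-check only the trailing partial tiles and place the causal diagonal correctly when qL != kL. The arithmetic behind them, sketched with the surrounding names:

// bq and bk are the query/key tile sizes; NQ_aligned and NK_aligned count
// the full tiles that need no bounds checks.
int NQ_aligned = qL / bq;
int NK_aligned = kL / bk;
int qL_rem = qL - NQ_aligned * bq; // rows in the trailing partial query tile
int kL_rem = kL - NK_aligned * bk; // columns in the trailing partial key tile
int qL_off = kL - qL; // causal offset: query i may attend keys j <= i + qL_off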
@@ -102,6 +115,15 @@ void sdpa_full_self_attention_metal(
   compute_encoder.set_output_array(o, 3);
   compute_encoder.set_bytes(params, 4);
 
+  if (mask) {
+    auto m = *mask;
+    AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
+        m.strides(0), m.strides(1), m.strides(2)}};
+
+    compute_encoder.set_bytes(mask_params, 5);
+    compute_encoder.set_input_array(m, 6);
+  }
+
   MTL::Size grid_dims = MTL::Size(NQ, H, B);
   MTL::Size group_dims = MTL::Size(32, wm, wn);
 
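Only the mask's strides travel in AttnMaskParams: a mask broadcast over batch or heads simply arrives with a 0 stride in that axis, so no copy is materialized. A hypothetical index computation of the kind these strides enable:

#include <cstdint>

// Offset of the mask element at (batch b, head h, query i, key j); a
// broadcast axis contributes nothing because its stride is 0. The last
// axis is kept contiguous (see the strides(-1) == 1 check below), so the
// key index j needs no explicit stride.
inline int64_t mask_offset(
    int64_t b, int64_t h, int64_t i, int64_t j, const int64_t M_strides[3]) {
  return b * M_strides[0] + h * M_strides[1] + i * M_strides[2] + j;
}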
@@ -346,7 +368,7 @@ void ScaledDotProductAttention::eval_gpu(
 
   // Checks that the headdim dimension has stride 1.
   auto is_matrix_contiguous = [](const array& arr) {
-    return arr.strides(3) == 1;
+    return arr.strides(-1) == 1;
   };
 
   // We are in vector mode ie single query
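Indexing the stride from the end makes the predicate rank-agnostic, which matters now that it is applied to the mask as well as to the 4-D q, k, and v. Assuming array::strides(int) resolves negative axes the way this change implies, the check is equivalent to:

auto is_last_axis_contiguous = [](const array& arr) {
  return arr.strides()[arr.ndim() - 1] == 1; // same as arr.strides(-1) == 1
};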
@@ -415,7 +437,11 @@ void ScaledDotProductAttention::eval_gpu(
         {str_oB, str_oH, str_oL, str_oD},
         flags);
 
-    sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
+    auto mask = inputs.size() > 3
+        ? std::optional<array>{copy_unless(is_matrix_contiguous, inputs[3])}
+        : std::nullopt;
+
+    sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask);
   }
 
   d.add_temporaries(std::move(copies), s.index);
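Taken together, the updated call site can dispatch three shapes; do_causal_ is state carried on the ScaledDotProductAttention primitive. A sketch of the possibilities:

// Unmasked: the defaults leave the masking machinery compiled out.
//   sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
// Fused causal: no mask array is ever materialized.
//   sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, true);
// Boolean or additive mask: inputs[3], made row-contiguous via copy_unless.
//   sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, false, mask);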