fix mask in sdpa (#1980)

* fix mask in sdpa

* fix attention mask

* Re-enable routing for array mask

---------

Co-authored-by: Jagrit Digani <digani@apple.com>
This commit is contained in:
Awni Hannun
2025-03-20 14:53:12 -07:00
committed by GitHub
parent b42d13ec84
commit 005e7efa64
3 changed files with 34 additions and 29 deletions

View File

@@ -1,5 +1,4 @@
// Copyright © 2024 Apple Inc.
#include <sstream>
#include "mlx/backend/common/compiled.h"
@@ -59,7 +58,7 @@ void sdpa_full_self_attention_metal(
<< "_bq" << bq
<< "_bk" << bk
<< "_bd" << bd
<< "_wm" << wm
<< "_wm" << wm
<< "_wn" << wn
<< "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on
@@ -67,7 +66,7 @@ void sdpa_full_self_attention_metal(
// clang-format off
kname << "_align_Q_" << (align_Q ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_has_mask_" << (has_mask ? 't' : 'n')
<< "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on
@@ -117,6 +116,7 @@ void sdpa_full_self_attention_metal(
if (mask) {
auto m = *mask;
AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
m.strides(0), m.strides(1), m.strides(2)}};