fix mask in sdpa (#1980)

* fix mask in sdpa

* fix attention mask

* Re-enable routing for array mask

---------

Co-authored-by: Jagrit Digani <digani@apple.com>
This commit is contained in:
Awni Hannun 2025-03-20 14:53:12 -07:00 committed by GitHub
parent b42d13ec84
commit 005e7efa64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 34 additions and 29 deletions

View File

@ -1,5 +1,4 @@
// Copyright © 2024 Apple Inc.
#include <sstream>
#include "mlx/backend/common/compiled.h"
@ -59,7 +58,7 @@ void sdpa_full_self_attention_metal(
<< "_bq" << bq
<< "_bk" << bk
<< "_bd" << bd
<< "_wm" << wm
<< "_wm" << wm
<< "_wn" << wn
<< "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on
@ -67,7 +66,7 @@ void sdpa_full_self_attention_metal(
// clang-format off
kname << "_align_Q_" << (align_Q ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_has_mask_" << (has_mask ? 't' : 'n')
<< "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on
@ -117,6 +116,7 @@ void sdpa_full_self_attention_metal(
if (mask) {
auto m = *mask;
AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
m.strides(0), m.strides(1), m.strides(2)}};

View File

@ -650,29 +650,6 @@ array scaled_dot_product_attention(
throw std::invalid_argument(msg.str());
}
if (has_arr_mask) {
// Check type
auto mask_arr = std::get<array>(mask);
has_bool_mask = mask_arr.dtype() == bool_;
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
<< final_type << ".";
throw std::invalid_argument(msg.str());
}
// Check shape
auto mask_shape = queries.shape();
mask_shape.back() = keys.shape(-2);
if (broadcast_shapes(mask_arr.shape(), mask_shape) != mask_shape) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Mask with shape "
<< mask_arr.shape()
<< " does not broadcast to implicit scores with shape " << mask_shape
<< ".";
throw std::invalid_argument(msg.str());
}
}
auto q = astype(queries, final_type, s);
auto k = astype(keys, final_type, s);
auto v = astype(values, final_type, s);
@ -748,8 +725,8 @@ array scaled_dot_product_attention(
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
const bool sdpa_full_supported_mask =
!has_mask || (query_sequence_length <= key_sequence_length && do_causal);
const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
(query_sequence_length <= key_sequence_length && do_causal);
const bool supports_sdpa_full = query_sequence_length >= threshold &&
sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
@ -765,7 +742,19 @@ array scaled_dot_product_attention(
std::vector<array> inputs = {q, k, v};
if (has_arr_mask) {
inputs.push_back(std::get<array>(mask));
// Check type
auto mask_arr = std::get<array>(mask);
has_bool_mask = mask_arr.dtype() == bool_;
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
<< final_type << ".";
throw std::invalid_argument(msg.str());
}
// Broadcast mask
auto mask_shape = queries.shape();
mask_shape.back() = keys.shape(-2);
inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
}
if (implementation_supports_use_case) {
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};

View File

@ -527,6 +527,22 @@ class TestSDPA(mlx_tests.MLXTestCase):
diff = mx.abs(out_fst - out_ref) - atol * mx.abs(out_ref)
self.assertLessEqual(mx.max(diff).item(), atol)
def test_sdpa_broadcast_mask(self):
mask = mx.array(True)
D = 64
Nq = 4
Nkv = 1
scale = 1.0
L = 256
mx.random.seed(0)
q = 5e-1 * mx.random.normal(shape=(1, Nq, L, D))
k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
ref = mlx_primitives_sdpa(q, k, v, scale, mask=mask)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
if __name__ == "__main__":
unittest.main(failfast=True)