From 005e7efa647dcb7a73e77dd62d3467d54176a68e Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 20 Mar 2025 14:53:12 -0700
Subject: [PATCH] fix mask in sdpa (#1980)

* fix mask in sdpa

* fix attention mask

* Re-enable routing for array mask

---------

Co-authored-by: Jagrit Digani
---
 .../metal/scaled_dot_product_attention.cpp |  6 +--
 mlx/fast.cpp                               | 41 +++++++------------
 python/tests/test_fast_sdpa.py             | 16 ++++++++
 3 files changed, 34 insertions(+), 29 deletions(-)

diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index f7ec004a66..c5b544852c 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -1,5 +1,4 @@
 // Copyright © 2024 Apple Inc.
-
 #include <sstream>
 
 #include "mlx/backend/common/compiled.h"
@@ -59,7 +58,7 @@ void sdpa_full_self_attention_metal(
        << "_bq" << bq
        << "_bk" << bk
        << "_bd" << bd
-       << "_wm" << wm 
+       << "_wm" << wm
        << "_wn" << wn
        << "_mask" << (type_to_name(has_mask ? *mask : q));
   // clang-format on
@@ -67,7 +66,7 @@ void sdpa_full_self_attention_metal(
 
   // clang-format off
   kname << "_align_Q_" << (align_Q ? 't' : 'n')
-        << "_align_K_" << (align_K ? 't' : 'n') 
+        << "_align_K_" << (align_K ? 't' : 'n')
         << "_has_mask_" << (has_mask ? 't' : 'n')
         << "_do_causal_" << (do_causal ? 't' : 'n');
   // clang-format on
@@ -117,6 +116,7 @@ void sdpa_full_self_attention_metal(
 
   if (mask) {
     auto m = *mask;
+
     AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
         m.strides(0), m.strides(1), m.strides(2)}};
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index ed0d9fbe53..ac3cfe042d 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -650,29 +650,6 @@ array scaled_dot_product_attention(
     throw std::invalid_argument(msg.str());
   }
 
-  if (has_arr_mask) {
-    // Check type
-    auto mask_arr = std::get<array>(mask);
-    has_bool_mask = mask_arr.dtype() == bool_;
-    if (promote_types(mask_arr.dtype(), final_type) != final_type) {
-      std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
-          << final_type << ".";
-      throw std::invalid_argument(msg.str());
-    }
-    // Check shape
-    auto mask_shape = queries.shape();
-    mask_shape.back() = keys.shape(-2);
-    if (broadcast_shapes(mask_arr.shape(), mask_shape) != mask_shape) {
-      std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Mask with shape "
-          << mask_arr.shape()
-          << " does not broadcast to implicit scores with shape " << mask_shape
-          << ".";
-      throw std::invalid_argument(msg.str());
-    }
-  }
-
   auto q = astype(queries, final_type, s);
   auto k = astype(keys, final_type, s);
   auto v = astype(values, final_type, s);
@@ -748,8 +725,8 @@ array scaled_dot_product_attention(
       (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
 
   const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
-  const bool sdpa_full_supported_mask =
-      !has_mask || (query_sequence_length <= key_sequence_length && do_causal);
+  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
+      (query_sequence_length <= key_sequence_length && do_causal);
 
   const bool supports_sdpa_full = query_sequence_length >= threshold &&
       sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
@@ -765,7 +742,19 @@ array scaled_dot_product_attention(
 
   std::vector<array> inputs = {q, k, v};
   if (has_arr_mask) {
-    inputs.push_back(std::get<array>(mask));
+    // Check type
+    auto mask_arr = std::get<array>(mask);
+    has_bool_mask = mask_arr.dtype() == bool_;
+    if (promote_types(mask_arr.dtype(), final_type) != final_type) {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
+          << final_type << ".";
+      throw std::invalid_argument(msg.str());
+    }
+    // Broadcast mask
+    auto mask_shape = queries.shape();
+    mask_shape.back() = keys.shape(-2);
+    inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
   }
   if (implementation_supports_use_case) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index 4ea573564a..78e03159fa 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -527,6 +527,22 @@ class TestSDPA(mlx_tests.MLXTestCase):
         diff = mx.abs(out_fst - out_ref) - atol * mx.abs(out_ref)
         self.assertLessEqual(mx.max(diff).item(), atol)
 
+    def test_sdpa_broadcast_mask(self):
+        mask = mx.array(True)
+        D = 64
+        Nq = 4
+        Nkv = 1
+        scale = 1.0
+        L = 256
+
+        mx.random.seed(0)
+        q = 5e-1 * mx.random.normal(shape=(1, Nq, L, D))
+        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        ref = mlx_primitives_sdpa(q, k, v, scale, mask=mask)
+        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
 
 if __name__ == "__main__":
     unittest.main(failfast=True)