Mirror of https://github.com/ml-explore/mlx.git
(synced 2025-08-09 18:56:39 +08:00)
fix mask in sdpa (#1980)
* fix mask in sdpa
* fix attention mask
* Re-enable routing for array mask

Co-authored-by: Jagrit Digani <digani@apple.com>
This commit is contained in:
parent: b42d13ec84
commit: 005e7efa64
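For context: the public entry point this commit touches is `mx.fast.scaled_dot_product_attention`, which accepts either a boolean mask (`True` keeps a score) or an additive float mask. A minimal usage sketch, with illustrative shapes that are not taken from the commit:

```python
import mlx.core as mx

B, H, L, D = 1, 8, 128, 64
q = mx.random.normal(shape=(B, H, L, D))
k = mx.random.normal(shape=(B, H, L, D))
v = mx.random.normal(shape=(B, H, L, D))
scale = D**-0.5

# Boolean mask: True keeps a score, False masks it out.
bool_mask = mx.tril(mx.ones((L, L), dtype=mx.bool_))
out_bool = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=bool_mask)

# Additive mask: added to the scores before the softmax.
add_mask = mx.where(bool_mask, 0.0, -mx.inf).astype(q.dtype)
out_add = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=add_mask)

assert mx.allclose(out_bool, out_add, atol=1e-4)
```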
mlx/backend/metal/scaled_dot_product_attention.cpp

@@ -1,5 +1,4 @@
 // Copyright © 2024 Apple Inc.
 #include <sstream>

 #include "mlx/backend/common/compiled.h"
@@ -59,7 +58,7 @@ void sdpa_full_self_attention_metal(
       << "_bq" << bq
       << "_bk" << bk
       << "_bd" << bd
       << "_wm" << wm
       << "_wn" << wn
       << "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on
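The `_mask` suffix above bakes the mask dtype into the kernel name, so each mask type gets its own compiled specialization; when no mask is given, the query dtype stands in. A rough Python illustration of that naming scheme (the base name, tile parameters, and type names here are assumptions for illustration, not values from the source):

```python
def attention_kernel_name(qtype, bq, bk, bd, wm, wn, mask_type):
    # Each specialization parameter is appended to the kernel name,
    # mirroring the C++ stream in the hunk above.
    return f"steel_attention_{qtype}_bq{bq}_bk{bk}_bd{bd}_wm{wm}_wn{wn}_mask{mask_type}"

# e.g. 'steel_attention_float16_bq32_bk16_bd64_wm4_wn1_maskbool_'
print(attention_kernel_name("float16", 32, 16, 64, 4, 1, "bool_"))
```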
@@ -67,7 +66,7 @@ void sdpa_full_self_attention_metal(
   // clang-format off
   kname << "_align_Q_" << (align_Q ? 't' : 'n')
         << "_align_K_" << (align_K ? 't' : 'n')
         << "_has_mask_" << (has_mask ? 't' : 'n')
         << "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on
@@ -117,6 +116,7 @@ void sdpa_full_self_attention_metal(
   if (mask) {
     auto m = *mask;

+    AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
+        m.strides(0), m.strides(1), m.strides(2)}};
mlx/fast.cpp
@@ -650,29 +650,6 @@ array scaled_dot_product_attention(
     throw std::invalid_argument(msg.str());
   }

-  if (has_arr_mask) {
-    // Check type
-    auto mask_arr = std::get<array>(mask);
-    has_bool_mask = mask_arr.dtype() == bool_;
-    if (promote_types(mask_arr.dtype(), final_type) != final_type) {
-      std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
-          << final_type << ".";
-      throw std::invalid_argument(msg.str());
-    }
-    // Check shape
-    auto mask_shape = queries.shape();
-    mask_shape.back() = keys.shape(-2);
-    if (broadcast_shapes(mask_arr.shape(), mask_shape) != mask_shape) {
-      std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Mask with shape "
-          << mask_arr.shape()
-          << " does not broadcast to implicit scores with shape " << mask_shape
-          << ".";
-      throw std::invalid_argument(msg.str());
-    }
-  }
-
   auto q = astype(queries, final_type, s);
   auto k = astype(keys, final_type, s);
   auto v = astype(values, final_type, s);
@@ -748,8 +725,8 @@ array scaled_dot_product_attention(
       (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);

   const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
-  const bool sdpa_full_supported_mask =
-      !has_mask || (query_sequence_length <= key_sequence_length && do_causal);
+  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
+      (query_sequence_length <= key_sequence_length && do_causal);

   const bool supports_sdpa_full = query_sequence_length >= threshold &&
       sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
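These predicates route a call to one of two Metal paths: the vector kernel accepts only boolean masks (`has_bool_mask`), while the full kernel now also takes array masks (`has_arr_mask`) in addition to the causal case. Whichever path is chosen must agree with the usual SDPA semantics; a reference sketch of those semantics, assuming equal query and key/value head counts (not the library's internal code):

```python
import mlx.core as mx

def sdpa_reference(q, k, v, scale, mask=None):
    # scores has the implicit shape (B, H, L_q, L_k).
    scores = (q * scale) @ k.transpose(0, 1, 3, 2)
    if mask is not None:
        if mask.dtype == mx.bool_:
            scores = mx.where(mask, scores, -mx.inf)  # boolean: keep/drop
        else:
            scores = scores + mask  # additive: bias the logits
    return mx.softmax(scores, axis=-1) @ v
```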
@ -765,7 +742,19 @@ array scaled_dot_product_attention(
|
||||
|
||||
std::vector<array> inputs = {q, k, v};
|
||||
if (has_arr_mask) {
|
||||
inputs.push_back(std::get<array>(mask));
|
||||
// Check type
|
||||
auto mask_arr = std::get<array>(mask);
|
||||
has_bool_mask = mask_arr.dtype() == bool_;
|
||||
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
|
||||
<< final_type << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
// Broadcast mask
|
||||
auto mask_shape = queries.shape();
|
||||
mask_shape.back() = keys.shape(-2);
|
||||
inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
|
||||
}
|
||||
if (implementation_supports_use_case) {
|
||||
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
|
||||
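The block above now validates and materializes the mask exactly where it joins the primitive's inputs: the dtype must promote to the compute type, and the shape must broadcast to the implicit scores shape, i.e. the query shape with its last axis replaced by the key sequence length. A host-side sketch of that shape rule, with illustrative shapes:

```python
import mlx.core as mx

q = mx.zeros((1, 8, 128, 64))  # (B, H, L_q, D)
k = mx.zeros((1, 8, 256, 64))  # (B, H, L_k, D)

# Implicit scores shape: query shape with the last axis set to L_k.
scores_shape = list(q.shape)
scores_shape[-1] = k.shape[-2]  # [1, 8, 128, 256]

# Even a scalar mask is expanded up front, so the kernel always sees
# an operand with well-defined strides.
mask = mx.broadcast_to(mx.array(True), scores_shape)
print(mask.shape)  # (1, 8, 128, 256)
```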
python/tests/test_fast_sdpa.py

@@ -527,6 +527,22 @@ class TestSDPA(mlx_tests.MLXTestCase):
         diff = mx.abs(out_fst - out_ref) - atol * mx.abs(out_ref)
         self.assertLessEqual(mx.max(diff).item(), atol)

+    def test_sdpa_broadcast_mask(self):
+        mask = mx.array(True)
+        D = 64
+        Nq = 4
+        Nkv = 1
+        scale = 1.0
+        L = 256
+
+        mx.random.seed(0)
+        q = 5e-1 * mx.random.normal(shape=(1, Nq, L, D))
+        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        ref = mlx_primitives_sdpa(q, k, v, scale, mask=mask)
+        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+

 if __name__ == "__main__":
     unittest.main(failfast=True)