mirror of https://github.com/ml-explore/mlx.git
synced 2025-08-10 19:26:42 +08:00

fix mask in sdpa (#1980)

* fix mask in sdpa
* fix attention mask
* Re-enable routing for array mask

---------

Co-authored-by: Jagrit Digani <digani@apple.com>

parent b42d13ec84
commit 005e7efa64
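
For context, a minimal usage sketch (not part of the commit) of the path this change fixes: calling mx.fast.scaled_dot_product_attention with an array mask. The shapes below are illustrative assumptions, not taken from the diff.

    import mlx.core as mx

    # Illustrative shapes: batch 1, 8 heads, sequence length 128, head dim 64.
    B, H, L, D = 1, 8, 128, 64
    q = mx.random.normal(shape=(B, H, L, D))
    k = mx.random.normal(shape=(B, H, L, D))
    v = mx.random.normal(shape=(B, H, L, D))

    # An additive float mask broadcastable to the (B, H, L, L) score shape.
    # This commit re-enables routing such array masks to the full attention kernel.
    mask = mx.zeros((1, 1, L, L))

    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5, mask=mask)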
@@ -1,5 +1,4 @@
 // Copyright © 2024 Apple Inc.
-
 #include <sstream>
 
 #include "mlx/backend/common/compiled.h"
@@ -59,7 +58,7 @@ void sdpa_full_self_attention_metal(
       << "_bq" << bq
       << "_bk" << bk
       << "_bd" << bd
       << "_wm" << wm
       << "_wn" << wn
       << "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on
 
@@ -67,7 +66,7 @@ void sdpa_full_self_attention_metal(
 
   // clang-format off
   kname << "_align_Q_" << (align_Q ? 't' : 'n')
         << "_align_K_" << (align_K ? 't' : 'n')
         << "_has_mask_" << (has_mask ? 't' : 'n')
         << "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on
 
@@ -117,6 +116,7 @@ void sdpa_full_self_attention_metal(
 
   if (mask) {
     auto m = *mask;
+
     AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
         m.strides(0), m.strides(1), m.strides(2)}};
 
mlx/fast.cpp (41 changed lines)
@@ -650,29 +650,6 @@ array scaled_dot_product_attention(
     throw std::invalid_argument(msg.str());
   }
 
-  if (has_arr_mask) {
-    // Check type
-    auto mask_arr = std::get<array>(mask);
-    has_bool_mask = mask_arr.dtype() == bool_;
-    if (promote_types(mask_arr.dtype(), final_type) != final_type) {
-      std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
-          << final_type << ".";
-      throw std::invalid_argument(msg.str());
-    }
-    // Check shape
-    auto mask_shape = queries.shape();
-    mask_shape.back() = keys.shape(-2);
-    if (broadcast_shapes(mask_arr.shape(), mask_shape) != mask_shape) {
-      std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Mask with shape "
-          << mask_arr.shape()
-          << " does not broadcast to implicit scores with shape " << mask_shape
-          << ".";
-      throw std::invalid_argument(msg.str());
-    }
-  }
-
   auto q = astype(queries, final_type, s);
   auto k = astype(keys, final_type, s);
   auto v = astype(values, final_type, s);
||||||
@ -748,8 +725,8 @@ array scaled_dot_product_attention(
|
|||||||
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
||||||
|
|
||||||
const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
|
const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
|
||||||
const bool sdpa_full_supported_mask =
|
const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
|
||||||
!has_mask || (query_sequence_length <= key_sequence_length && do_causal);
|
(query_sequence_length <= key_sequence_length && do_causal);
|
||||||
|
|
||||||
const bool supports_sdpa_full = query_sequence_length >= threshold &&
|
const bool supports_sdpa_full = query_sequence_length >= threshold &&
|
||||||
sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
|
sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
|
||||||
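
The two-line change above is the routing fix itself: previously an array mask made sdpa_full_supported_mask false, so array masks could never reach the fused full-attention kernel; with has_arr_mask in the disjunction they are eligible again. A hedged Python restatement of the predicate (variable names mirror the C++; the threshold default is an assumption, not read from the source):

    def supports_sdpa_full(query_len, key_len, has_mask, has_arr_mask,
                           do_causal, head_dim, threshold=32):
        # Restates the C++ routing logic for illustration only.
        supported_head_dim = head_dim in (64, 80, 128)
        supported_mask = (not has_mask) or has_arr_mask or (
            query_len <= key_len and do_causal)
        return query_len >= threshold and supported_mask and supported_head_dim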
@@ -765,7 +742,19 @@ array scaled_dot_product_attention(
 
   std::vector<array> inputs = {q, k, v};
   if (has_arr_mask) {
-    inputs.push_back(std::get<array>(mask));
+    // Check type
+    auto mask_arr = std::get<array>(mask);
+    has_bool_mask = mask_arr.dtype() == bool_;
+    if (promote_types(mask_arr.dtype(), final_type) != final_type) {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
+          << final_type << ".";
+      throw std::invalid_argument(msg.str());
+    }
+    // Broadcast mask
+    auto mask_shape = queries.shape();
+    mask_shape.back() = keys.shape(-2);
+    inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
   }
   if (implementation_supports_use_case) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
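
Note that the mask validation moved here from earlier in the function, and the mask is now broadcast to the implicit score shape (queries.shape() with the last axis replaced by keys.shape(-2)) before being pushed onto the kernel inputs, so scalar or partially broadcast masks arrive fully expanded. A small sketch of that shape computation in Python (a restatement under assumed shapes, not the library's internals):

    import mlx.core as mx

    q = mx.zeros((1, 4, 256, 64))  # (B, H, L_q, D)
    k = mx.zeros((1, 4, 256, 64))  # (B, H, L_k, D)
    mask = mx.array(True)          # scalar mask, shape ()

    # Score shape: q.shape with the last axis set to k.shape(-2).
    score_shape = list(q.shape)
    score_shape[-1] = k.shape[-2]
    print(mx.broadcast_to(mask, score_shape).shape)  # (1, 4, 256, 256)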
@@ -527,6 +527,22 @@ class TestSDPA(mlx_tests.MLXTestCase):
         diff = mx.abs(out_fst - out_ref) - atol * mx.abs(out_ref)
         self.assertLessEqual(mx.max(diff).item(), atol)
 
+    def test_sdpa_broadcast_mask(self):
+        mask = mx.array(True)
+        D = 64
+        Nq = 4
+        Nkv = 1
+        scale = 1.0
+        L = 256
+
+        mx.random.seed(0)
+        q = 5e-1 * mx.random.normal(shape=(1, Nq, L, D))
+        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        ref = mlx_primitives_sdpa(q, k, v, scale, mask=mask)
+        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
 
 if __name__ == "__main__":
     unittest.main(failfast=True)
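
The new test exercises exactly this path: a scalar boolean mask (shape ()) broadcast against grouped KV heads (Nq=4, Nkv=1). mlx_primitives_sdpa is the pure-ops reference defined earlier in the test file; below is a minimal sketch of what such a reference computes (an assumption about its shape handling, not the test's exact helper):

    import mlx.core as mx

    def reference_sdpa(q, k, v, scale, mask=None):
        # Pure-ops scaled dot-product attention with grouped KV heads,
        # used as a numerical baseline. Sketch only.
        n_repeats = q.shape[1] // k.shape[1]
        if n_repeats > 1:
            k = mx.repeat(k, n_repeats, axis=1)
            v = mx.repeat(v, n_repeats, axis=1)
        scores = scale * (q @ k.transpose(0, 1, 3, 2))
        if mask is not None:
            if mask.dtype == mx.bool_:
                scores = mx.where(mask, scores, float("-inf"))
            else:
                scores = scores + mask
        return mx.softmax(scores, axis=-1) @ v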