mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fused attention for single query (#1497)
This commit is contained in:
committed by
GitHub
parent
9dd72cd421
commit
50d8bed468
34
mlx/fast.cpp
34
mlx/fast.cpp
@@ -618,40 +618,38 @@ array scaled_dot_product_attention(
|
||||
};
|
||||
|
||||
auto stream = to_stream(s);
|
||||
const size_t value_head_dim = v.shape(-1);
|
||||
const size_t query_head_dim = q.shape(-1);
|
||||
const bool supported_head_dim =
|
||||
query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128;
|
||||
|
||||
const bool supported_head_dim_self_attn =
|
||||
query_head_dim == 64 || query_head_dim == 128;
|
||||
const size_t query_sequence_length = q.shape(2);
|
||||
const bool supports_full_self_attention = query_sequence_length >= 16 &&
|
||||
!mask.has_value() && supported_head_dim_self_attn &&
|
||||
|
||||
bool implementation_supports_use_case = query_head_dim == value_head_dim;
|
||||
|
||||
const bool sdpa_vector_supported_head_dim =
|
||||
query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
|
||||
const bool sdpa_full_supported_head_dim =
|
||||
query_head_dim == 64 || query_head_dim == 128;
|
||||
|
||||
const bool supports_sdpa_full = query_sequence_length >= threshold &&
|
||||
!mask.has_value() && sdpa_full_supported_head_dim &&
|
||||
n_q_heads == n_kv_heads && final_type != bfloat16 &&
|
||||
stream.device == Device::gpu;
|
||||
|
||||
// fast decoding gpu shader
|
||||
bool supports_sdpa = batch_dim == 1 && query_sequence_length == 1 &&
|
||||
!mask.has_value() && supported_head_dim && final_type != bfloat16 &&
|
||||
const bool supports_sdpa_vector = query_sequence_length == 1 &&
|
||||
!mask.has_value() && sdpa_vector_supported_head_dim &&
|
||||
stream.device == Device::gpu;
|
||||
bool implementation_supports_use_case =
|
||||
supports_sdpa || supports_full_self_attention;
|
||||
|
||||
// sdpa gpu shader is disabled except for memory efficient opt-in
|
||||
const int seq_for_threshold = queries.shape(2);
|
||||
bool use_memory_efficient_impl = seq_for_threshold >= threshold;
|
||||
implementation_supports_use_case &= use_memory_efficient_impl;
|
||||
implementation_supports_use_case &=
|
||||
supports_sdpa_full || supports_sdpa_vector;
|
||||
|
||||
if (implementation_supports_use_case) {
|
||||
auto out_shape =
|
||||
std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
|
||||
auto out = array(
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
final_type,
|
||||
std::make_shared<ScaledDotProductAttention>(
|
||||
stream, fallback, scale, false),
|
||||
{q, k, v});
|
||||
return out;
|
||||
}
|
||||
|
||||
if (mask.has_value()) {
|
||||
|
||||
Reference in New Issue
Block a user