Fused attention for single query (#1497)

This commit is contained in:
Angelos Katharopoulos
2024-10-18 00:58:52 -07:00
committed by GitHub
parent 9dd72cd421
commit 50d8bed468
6 changed files with 299 additions and 742 deletions

View File

@@ -618,40 +618,38 @@ array scaled_dot_product_attention(
};
auto stream = to_stream(s);
const size_t value_head_dim = v.shape(-1);
const size_t query_head_dim = q.shape(-1);
const bool supported_head_dim =
query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128;
const bool supported_head_dim_self_attn =
query_head_dim == 64 || query_head_dim == 128;
const size_t query_sequence_length = q.shape(2);
const bool supports_full_self_attention = query_sequence_length >= 16 &&
!mask.has_value() && supported_head_dim_self_attn &&
bool implementation_supports_use_case = query_head_dim == value_head_dim;
const bool sdpa_vector_supported_head_dim =
query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
const bool sdpa_full_supported_head_dim =
query_head_dim == 64 || query_head_dim == 128;
const bool supports_sdpa_full = query_sequence_length >= threshold &&
!mask.has_value() && sdpa_full_supported_head_dim &&
n_q_heads == n_kv_heads && final_type != bfloat16 &&
stream.device == Device::gpu;
// fast decoding gpu shader
bool supports_sdpa = batch_dim == 1 && query_sequence_length == 1 &&
!mask.has_value() && supported_head_dim && final_type != bfloat16 &&
const bool supports_sdpa_vector = query_sequence_length == 1 &&
!mask.has_value() && sdpa_vector_supported_head_dim &&
stream.device == Device::gpu;
bool implementation_supports_use_case =
supports_sdpa || supports_full_self_attention;
// sdpa gpu shader is disabled except for memory efficient opt-in
const int seq_for_threshold = queries.shape(2);
bool use_memory_efficient_impl = seq_for_threshold >= threshold;
implementation_supports_use_case &= use_memory_efficient_impl;
implementation_supports_use_case &=
supports_sdpa_full || supports_sdpa_vector;
if (implementation_supports_use_case) {
auto out_shape =
std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
auto out = array(
return array(
std::move(out_shape),
final_type,
std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, false),
{q, k, v});
return out;
}
if (mask.has_value()) {