Allow different value dimensions in sdpa_vector (#1811)

This commit is contained in:
Angelos Katharopoulos
2025-01-31 20:58:59 -08:00
committed by GitHub
parent b7c9f1d38f
commit f5cc1eea72
6 changed files with 127 additions and 72 deletions

View File

@@ -684,23 +684,20 @@ array scaled_dot_product_attention(
 const size_t query_head_dim = q.shape(-1);
 const size_t query_sequence_length = q.shape(2);
-bool implementation_supports_use_case = query_head_dim == value_head_dim;
 const bool sdpa_vector_supported_head_dim =
-query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
-const bool sdpa_full_supported_head_dim =
-query_head_dim == 64 || query_head_dim == 80;
+query_head_dim == value_head_dim &&
+(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
+const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
+(query_head_dim == 64 || query_head_dim == 80);
-const bool supports_sdpa_full = query_sequence_length >= threshold &&
-!mask.has_value() && sdpa_full_supported_head_dim &&
-stream.device == Device::gpu;
+const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
+sdpa_full_supported_head_dim && stream.device == Device::gpu;
-const bool supported_mask = !mask || (mask->dtype() == bool_);
 const bool supports_sdpa_vector = query_sequence_length == 1 &&
-supported_mask && sdpa_vector_supported_head_dim &&
+(!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim &&
 stream.device == Device::gpu;
-implementation_supports_use_case &=
+const bool implementation_supports_use_case =
 supports_sdpa_full || supports_sdpa_vector;
 std::vector<array> inputs = {q, k, v};