SDPA support for small batch (over sequence) queries (#1922)

* batch query sdpa

* batch sdpa for query
This commit is contained in:
Awni Hannun
2025-03-04 10:59:04 -08:00
committed by GitHub
parent 6bcd6bcf70
commit e613d0eaf0
5 changed files with 159 additions and 45 deletions

View File

@@ -715,7 +715,8 @@ array scaled_dot_product_attention(
const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
sdpa_full_supported_head_dim && stream.device == Device::gpu;
const bool supports_sdpa_vector = query_sequence_length == 1 &&
const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
(query_sequence_length <= k.shape(-2)) &&
(!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim &&
stream.device == Device::gpu;