sdpa specialization for head dim 256 (#2007)

This commit is contained in:
Awni Hannun
2025-03-27 19:31:25 -07:00
committed by GitHub
parent a6b5d6e759
commit bc62932984
2 changed files with 5 additions and 2 deletions

View File

@@ -720,7 +720,8 @@ array scaled_dot_product_attention(
const bool sdpa_vector_supported_head_dim =
query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
query_head_dim == 256);
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);