Enable GQA support

This commit is contained in:
Jagrit Digani 2024-11-20 15:24:41 -08:00
parent 0c22440c75
commit 140301aea8
2 changed files with 5 additions and 5 deletions

View File

@ -91,11 +91,12 @@ template <
tidl.y * params->Q_strides[1] + // Head
tidl.x * BQ * params->Q_strides[2]; // Sequence
ulong kv_head_idx = int(tid.y) / params->gqa_factor;
K += tidl.z * params->K_strides[0] + // Batch
tidl.y * params->K_strides[1]; // Head
kv_head_idx * params->K_strides[1]; // Head
V += tidl.z * params->V_strides[0] + // Batch
tidl.y * params->V_strides[1]; // Head
kv_head_idx * params->V_strides[1]; // Head
O += tidl.z * params->O_strides[0] + // Batch
tidl.y * params->O_strides[1] + // Head

View File

@ -644,12 +644,11 @@ array scaled_dot_product_attention(
const bool sdpa_vector_supported_head_dim =
query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
const bool sdpa_full_supported_head_dim =
query_head_dim == 64 || query_head_dim == 80;
query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128;
const bool supports_sdpa_full = query_sequence_length >= threshold &&
!mask.has_value() && sdpa_full_supported_head_dim &&
n_q_heads == n_kv_heads && final_type != bfloat16 &&
stream.device == Device::gpu;
final_type != bfloat16 && stream.device == Device::gpu;
const bool supports_sdpa_vector = query_sequence_length == 1 &&
!mask.has_value() && sdpa_vector_supported_head_dim &&