mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-07 04:06:42 +08:00
Enable gqa support
This commit is contained in:
parent
0c22440c75
commit
140301aea8
@ -91,11 +91,12 @@ template <
|
||||
tidl.y * params->Q_strides[1] + // Head
|
||||
tidl.x * BQ * params->Q_strides[2]; // Seqeunce
|
||||
|
||||
ulong kv_head_idx = int(tid.y) / params->gqa_factor;
|
||||
K += tidl.z * params->K_strides[0] + // Batch
|
||||
tidl.y * params->K_strides[1]; // Head
|
||||
kv_head_idx * params->K_strides[1]; // Head
|
||||
|
||||
V += tidl.z * params->V_strides[0] + // Batch
|
||||
tidl.y * params->V_strides[1]; // Head
|
||||
kv_head_idx * params->V_strides[1]; // Head
|
||||
|
||||
O += tidl.z * params->O_strides[0] + // Batch
|
||||
tidl.y * params->O_strides[1] + // Head
|
||||
|
@ -644,12 +644,11 @@ array scaled_dot_product_attention(
|
||||
const bool sdpa_vector_supported_head_dim =
|
||||
query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
|
||||
const bool sdpa_full_supported_head_dim =
|
||||
query_head_dim == 64 || query_head_dim == 80;
|
||||
query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128;
|
||||
|
||||
const bool supports_sdpa_full = query_sequence_length >= threshold &&
|
||||
!mask.has_value() && sdpa_full_supported_head_dim &&
|
||||
n_q_heads == n_kv_heads && final_type != bfloat16 &&
|
||||
stream.device == Device::gpu;
|
||||
final_type != bfloat16 && stream.device == Device::gpu;
|
||||
|
||||
const bool supports_sdpa_vector = query_sequence_length == 1 &&
|
||||
!mask.has_value() && sdpa_vector_supported_head_dim &&
|
||||
|
Loading…
Reference in New Issue
Block a user