Mirror of https://github.com/ml-explore/mlx.git
Enabling fused attention for head dim 128 (#1899)
* Share KV smem
* Fix bfloat error
* Unroll O = S @ V loop
* Perf upgrade
* Remove commented out function
* Add -Wno-c++17-extensions flag to metal flags
* Add -Wno-c++17-extensions flag to metal extension flags
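For context (not part of the commit itself): S here denotes the row-softmaxed attention score matrix, so "Unroll O = S @ V loop" refers to unrolling the accumulation of the attention output O inside the Metal kernel. A minimal Python sketch of the same math, assuming MLX's Python bindings and illustrative shape names:

import mlx.core as mx

def sdpa_reference(q, k, v, scale):
    # q, k, v: [batch, heads, seq, head_dim]; names are illustrative only.
    s = mx.softmax((q @ mx.swapaxes(k, -1, -2)) * scale, axis=-1)  # S = softmax(Q K^T * scale)
    return s @ v  # O = S @ V, the loop the commit unrolls in the fused kernel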
@@ -693,7 +693,7 @@ array scaled_dot_product_attention(
       query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
   const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 80);
+      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
 
   const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
       sdpa_full_supported_head_dim && stream.device == Device::gpu;
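A usage-level sketch, not from the commit: per the gating above, the fused full-attention path also requires a long enough query sequence, no mask, and a GPU stream; this change makes head dim 128 eligible for it. Assumes MLX is installed with Metal GPU support.

import mlx.core as mx

B, H, L, D = 1, 8, 1024, 128  # head dim 128, newly allowed for the fused full kernel
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))

# Same math as softmax(Q K^T * scale) @ V; whether the fused Metal kernel is
# used is decided inside MLX by the conditions shown in the diff.
o = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** -0.5)
mx.eval(o)
print(o.shape)  # (1, 8, 1024, 128)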