Enabling fused attention for head dim 128 (#1899)

* Share KV smem

* Fix bfloat error

* Unroll O = S @ V loop (sketched below)

* Perf upgrade

* Remove commented out function

* Add -Wno-c++17-extensions flag to metal flags

* Add -Wno-c++17-extensions flag to metal extension flags
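
The "Unroll O = S @ V loop" item refers to unrolling the inner accumulation that builds the attention output O from a block of softmax scores S and value rows V. Below is a minimal, generic sketch of that technique; the names, the compile-time sizes BK (key-block size) and D (head dim), and the plain C++ form are illustrative assumptions, not the Metal kernel changed in this commit.

// Sketch only: accumulate one row of the attention output O from a block of
// scores S and values V, with the key loop unrolled. Not MLX's kernel code.
template <int BK, int D>
void accumulate_o(const float (&S)[BK], const float (&V)[BK][D], float (&O)[D]) {
  for (int d = 0; d < D; ++d) {
    float acc = O[d];
    // Unrolling the key loop lets the compiler keep partial products in
    // registers and schedule the multiply-adds without loop overhead.
#pragma unroll
    for (int k = 0; k < BK; ++k) {
      acc += S[k] * V[k][d];
    }
    O[d] = acc;
  }
}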
Author: Jagrit Digani
Date: 2025-02-26 10:02:06 -08:00
Committed by: GitHub
Parent: 6bf00ef631
Commit: 89d327075f
5 changed files with 102 additions and 46 deletions

@@ -693,7 +693,7 @@ array scaled_dot_product_attention(
       query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
   const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 80);
+      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
   const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
       sdpa_full_supported_head_dim && stream.device == Device::gpu;
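
The hunk above widens the head-dim check so the fused full-attention path can also be selected for head dim 128, alongside 64 and 80. A minimal, self-contained sketch of that dispatch predicate follows; the function name supports_fused_sdpa and its parameter list are illustrative assumptions, not part of MLX's API.

// Sketch of the dispatch predicate after this change (names are illustrative).
bool supports_fused_sdpa(
    int query_head_dim,
    int value_head_dim,
    int query_sequence_length,
    int threshold,
    bool has_mask,
    bool on_gpu) {
  // Head dim 128 now joins 64 and 80 as a size the full fused kernel accepts.
  const bool supported_head_dim = query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
  // The fused path is only taken for long, unmasked sequences on the GPU.
  return query_sequence_length >= threshold && !has_mask &&
      supported_head_dim && on_gpu;
}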