Matrix Attention kernel (#1610)

* Rough INIT

* [WIP]: Loading and Matmuls added

* [WIP]: Reductions and minimal working aligned kernel at headdim = 64

* [WIP] Added headdim 80 for testing

* [WIP] Update dispatch params for testing

* [WIP] Add support for unaligned seq lengths - still looks messy

* Update sdpa_benchmarks

* Update sdpa_benchmarks

* Update sdpa_benchmarks

* Enable gqa support

* Update benchmark and switch off 128 headdim

* Update headdim 128 tuning

* Remove older fast attention code. Write out O strided

* Disable hd=128 until further optimizations

* Enable bf16

* Fix data size bug

* Enable attn build outside of jit
Author: Jagrit Digani
Date: 2024-11-22 10:34:05 -08:00
Committed by: GitHub
Parent: c79f6a4a8c
Commit: 02bec0bb6d
14 changed files with 2049 additions and 1109 deletions
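
For context, the kernel introduced here targets the standard scaled dot-product attention computation, extended to grouped-query attention (GQA) where several query heads share one key/value head. The reference below is a minimal, illustrative C++ sketch of that computation, not code from this PR; the function name, flat `std::vector` layout, and single-batch assumption are all illustrative choices.

#include <algorithm>
#include <cmath>
#include <vector>

// Naive reference for scaled dot-product attention with a GQA head mapping.
// Layouts (row-major, single batch): Q is [n_q_heads][L][D], K and V are
// [n_kv_heads][S][D], output O is [n_q_heads][L][D] and must be pre-sized.
// Assumes n_q_heads % n_kv_heads == 0.
void sdpa_reference(
    const std::vector<float>& Q,
    const std::vector<float>& K,
    const std::vector<float>& V,
    std::vector<float>& O,
    int n_q_heads,
    int n_kv_heads,
    int L,   // query sequence length
    int S,   // key/value sequence length
    int D) { // head dimension (e.g. 64 or 80 for the full kernel here)
  const float scale = 1.0f / std::sqrt(static_cast<float>(D));
  const int group = n_q_heads / n_kv_heads; // query heads per kv head
  for (int h = 0; h < n_q_heads; ++h) {
    const int kvh = h / group; // key/value head shared by this query head
    for (int i = 0; i < L; ++i) {
      // scores_j = scale * (q_i . k_j), then a numerically stable softmax
      std::vector<float> scores(S);
      float max_score = -INFINITY;
      for (int j = 0; j < S; ++j) {
        float s = 0.0f;
        for (int d = 0; d < D; ++d) {
          s += Q[(h * L + i) * D + d] * K[(kvh * S + j) * D + d];
        }
        scores[j] = s * scale;
        max_score = std::max(max_score, scores[j]);
      }
      float denom = 0.0f;
      for (int j = 0; j < S; ++j) {
        scores[j] = std::exp(scores[j] - max_score);
        denom += scores[j];
      }
      // o_i = sum_j softmax(scores)_j * v_j
      for (int d = 0; d < D; ++d) {
        float acc = 0.0f;
        for (int j = 0; j < S; ++j) {
          acc += scores[j] * V[(kvh * S + j) * D + d];
        }
        O[(h * L + i) * D + d] = acc / denom;
      }
    }
  }
}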


@@ -600,7 +600,7 @@ array scaled_dot_product_attention(
   * * dtype is not fp32 or fp16
   */
-  int threshold = 1e6;
+  int threshold = 32; // TODO: Fix after dev
   if (memory_efficient_threshold.has_value()) {
     threshold = std::max(1, memory_efficient_threshold.value());
   }
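
In this hunk the sequence-length threshold that gates the full (matrix) kernel changes from 1e6 to a default of 32, flagged in the diff as a temporary dev value, and a caller-supplied `memory_efficient_threshold` still overrides it, clamped to at least 1. A minimal sketch of that resolution logic follows; the helper name is hypothetical, not part of the library.

#include <algorithm>
#include <optional>

// Hypothetical helper mirroring the hunk above: a caller-provided value wins
// but is clamped to >= 1; otherwise the (dev) default of 32 applies.
int resolve_sdpa_threshold(std::optional<int> memory_efficient_threshold,
                           int default_threshold = 32) {
  int threshold = default_threshold;
  if (memory_efficient_threshold.has_value()) {
    threshold = std::max(1, memory_efficient_threshold.value());
  }
  return threshold;
}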
@@ -644,11 +644,10 @@ array scaled_dot_product_attention(
   const bool sdpa_vector_supported_head_dim =
       query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
   const bool sdpa_full_supported_head_dim =
-      query_head_dim == 64 || query_head_dim == 128;
+      query_head_dim == 64 || query_head_dim == 80;
   const bool supports_sdpa_full = query_sequence_length >= threshold &&
       !mask.has_value() && sdpa_full_supported_head_dim &&
-      n_q_heads == n_kv_heads && final_type != bfloat16 &&
       stream.device == Device::gpu;
   const bool supports_sdpa_vector = query_sequence_length == 1 &&
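
Read together with the commit notes ("Enable gqa support", "Enable bf16", "Disable hd=128 until further optimizations"), the updated predicate routes a call to the full matrix kernel when the head dimension is 64 or 80, the query sequence length meets the threshold, no mask is supplied, and the op runs on the GPU, while the earlier `n_q_heads == n_kv_heads` and non-bfloat16 restrictions are dropped. The standalone sketch below restates that check with hypothetical names; the `Device` enum is a placeholder, not an MLX declaration.

// Illustrative restatement of the dispatch condition after this change.
enum class Device { cpu, gpu };

bool supports_sdpa_full(
    int query_head_dim,
    int query_sequence_length,
    int threshold,
    bool has_mask,
    Device device) {
  // 96 and 128 remain vector-kernel only; 128 is disabled here pending tuning.
  const bool supported_head_dim =
      query_head_dim == 64 || query_head_dim == 80;
  return query_sequence_length >= threshold && !has_mask &&
      supported_head_dim && device == Device::gpu;
}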