[WIP] Added headdim 80 for testing

Jagrit Digani 2024-11-19 18:21:51 -08:00
parent d927ed9e32
commit 2cd1de0e47
3 changed files with 5 additions and 3 deletions


@@ -21,6 +21,8 @@
     uint3 lid [[thread_position_in_threadgroup]]);
 #define instantiate_attn_shapes_helper(iname, itype) \
+  instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
+  instantiate_attn(iname, itype, 32, 16, 80, 4, 1) \
   instantiate_attn(iname, itype, 32, 32, 64, 4, 1) \
   instantiate_attn(iname, itype, 32, 16, 64, 4, 1) \
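
The two added instantiate_attn lines register head-dim-80 tile shapes (bq=32 paired with bk=32 and bk=16) alongside the existing head-dim-64 ones, so a precompiled Metal kernel exists for the configuration the host side selects in the next file.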


@@ -30,8 +30,8 @@ void sdpa_full_self_attention_metal(
   int wn = 1;
   int bq = 32;
-  int bk = 16;
-  int bd = 64;
+  int bk = 32;
+  int bd = q.shape(-1);
   std::ostringstream kname;
   kname << "steel_attention_" << type_to_name(q) << "_bq" << bq << "_bk" << bk


@@ -644,7 +644,7 @@ array scaled_dot_product_attention(
   const bool sdpa_vector_supported_head_dim =
       query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
   const bool sdpa_full_supported_head_dim =
-      query_head_dim == 64 || query_head_dim == 128;
+      query_head_dim == 64 || query_head_dim == 80;
   const bool supports_sdpa_full = query_sequence_length >= threshold &&
       !mask.has_value() && sdpa_full_supported_head_dim &&
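
This gate must stay in sync with the instantiations in the first file: head dim 80 now routes to the full steel kernel while 128 no longer does. A small sketch of the net dispatch effect; the boolean matches the hunk above, but wrapping it in a helper and the demo loop are illustrative only:

#include <initializer_list>
#include <iostream>

// Illustrative helper mirroring the updated condition above.
bool sdpa_full_supported_head_dim(int query_head_dim) {
  return query_head_dim == 64 || query_head_dim == 80;
}

int main() {
  for (int d : {64, 80, 96, 128}) {
    std::cout << "head dim " << d << ": "
              << (sdpa_full_supported_head_dim(d) ? "steel full kernel"
                                                  : "fallback path")
              << "\n";
  }
}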