Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-08 12:05:34 +08:00)
[WIP] Added headdim 80 for testing

commit 2cd1de0e47
parent d927ed9e32
@@ -21,6 +21,8 @@
     uint3 lid [[thread_position_in_threadgroup]]);
 
 #define instantiate_attn_shapes_helper(iname, itype) \
+  instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
+  instantiate_attn(iname, itype, 32, 16, 80, 4, 1) \
   instantiate_attn(iname, itype, 32, 32, 64, 4, 1) \
   instantiate_attn(iname, itype, 32, 16, 64, 4, 1) \
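Note: a minimal sketch of the tile-shape table this helper now stamps out, with each entry mirroring one instantiate_attn(iname, itype, bq, bk, bd, wm, wn) line from the hunk above. The struct and array names are illustrative, not from the source, and the wm/wn comments are an assumed interpretation:

    // Hypothetical summary of the instantiated attention tile shapes.
    // The two bd == 80 entries are the ones this commit adds.
    struct AttnShape {
      int bq; // query block rows
      int bk; // key block columns
      int bd; // head dimension
      int wm; // assumed: simdgroups along M
      int wn; // assumed: simdgroups along N
    };
    constexpr AttnShape attn_shapes[] = {
        {32, 32, 80, 4, 1}, // new: head dim 80
        {32, 16, 80, 4, 1}, // new: head dim 80
        {32, 32, 64, 4, 1},
        {32, 16, 64, 4, 1},
    };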
@@ -30,8 +30,8 @@ void sdpa_full_self_attention_metal(
   int wn = 1;
 
   int bq = 32;
-  int bk = 16;
-  int bd = 64;
+  int bk = 32;
+  int bd = q.shape(-1);
 
   std::ostringstream kname;
   kname << "steel_attention_" << type_to_name(q) << "_bq" << bq << "_bk" << bk
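Note: the kname line above is truncated in the hunk. Below is a minimal, self-contained sketch of the name it plausibly builds, assuming the stream continues with a "_bd" << bd suffix and that type_to_name(q) yields a short dtype tag such as "float16"; both are assumptions, not taken from this diff:

    #include <iostream>
    #include <sstream>

    int main() {
      int bq = 32;
      int bk = 32;
      int bd = 80; // with bd = q.shape(-1), a head dim of 80 now reaches here
      std::ostringstream kname;
      kname << "steel_attention_" << "float16" << "_bq" << bq << "_bk" << bk
            << "_bd" << bd; // assumed continuation of the truncated line
      std::cout << kname.str() << '\n'; // steel_attention_float16_bq32_bk32_bd80
      return 0;
    }

For the kernel lookup to succeed, this string must match a host_name produced by one of the instantiate_attn shapes added above, which is why the instantiation and dispatch changes land together.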
@@ -644,7 +644,7 @@ array scaled_dot_product_attention(
   const bool sdpa_vector_supported_head_dim =
       query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
   const bool sdpa_full_supported_head_dim =
-      query_head_dim == 64 || query_head_dim == 128;
+      query_head_dim == 64 || query_head_dim == 80;
 
   const bool supports_sdpa_full = query_sequence_length >= threshold &&
       !mask.has_value() && sdpa_full_supported_head_dim &&
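Note: pulling the two head-dim predicates out of the hunk above into standalone form for clarity. The free-function wrappers are illustrative; the conditions are verbatim from the diff, with the replaced 128 branch kept as a comment:

    // Head dims the full (tiled) SDPA kernel accepts after this commit.
    bool sdpa_full_supported_head_dim(int query_head_dim) {
      // Was: query_head_dim == 64 || query_head_dim == 128;
      return query_head_dim == 64 || query_head_dim == 80;
    }

    // Head dims the vector SDPA kernel accepts (unchanged by this commit).
    bool sdpa_vector_supported_head_dim(int query_head_dim) {
      return query_head_dim == 64 || query_head_dim == 96 ||
             query_head_dim == 128;
    }

Consistent with the [WIP] title, this trades away head dim 128 support in the full kernel to route head dim 80 onto it for testing.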