mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-08 20:31:13 +08:00
[WIP] Added headdim 80 for testing
This commit is contained in:
parent
d927ed9e32
commit
2cd1de0e47
@ -21,6 +21,8 @@
|
|||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
#define instantiate_attn_shapes_helper(iname, itype) \
|
#define instantiate_attn_shapes_helper(iname, itype) \
|
||||||
|
instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
|
||||||
|
instantiate_attn(iname, itype, 32, 16, 80, 4, 1) \
|
||||||
instantiate_attn(iname, itype, 32, 32, 64, 4, 1) \
|
instantiate_attn(iname, itype, 32, 32, 64, 4, 1) \
|
||||||
instantiate_attn(iname, itype, 32, 16, 64, 4, 1) \
|
instantiate_attn(iname, itype, 32, 16, 64, 4, 1) \
|
||||||
|
|
||||||
|
@ -30,8 +30,8 @@ void sdpa_full_self_attention_metal(
|
|||||||
int wn = 1;
|
int wn = 1;
|
||||||
|
|
||||||
int bq = 32;
|
int bq = 32;
|
||||||
int bk = 16;
|
int bk = 32;
|
||||||
int bd = 64;
|
int bd = q.shape(-1);
|
||||||
|
|
||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
kname << "steel_attention_" << type_to_name(q) << "_bq" << bq << "_bk" << bk
|
kname << "steel_attention_" << type_to_name(q) << "_bq" << bq << "_bk" << bk
|
||||||
|
@ -644,7 +644,7 @@ array scaled_dot_product_attention(
|
|||||||
const bool sdpa_vector_supported_head_dim =
|
const bool sdpa_vector_supported_head_dim =
|
||||||
query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
|
query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
|
||||||
const bool sdpa_full_supported_head_dim =
|
const bool sdpa_full_supported_head_dim =
|
||||||
query_head_dim == 64 || query_head_dim == 128;
|
query_head_dim == 64 || query_head_dim == 80;
|
||||||
|
|
||||||
const bool supports_sdpa_full = query_sequence_length >= threshold &&
|
const bool supports_sdpa_full = query_sequence_length >= threshold &&
|
||||||
!mask.has_value() && sdpa_full_supported_head_dim &&
|
!mask.has_value() && sdpa_full_supported_head_dim &&
|
||||||
|
Loading…
Reference in New Issue
Block a user