mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 18:11:17 +08:00
sdpa specialization for head dim 256 (#2007)
This commit is contained in:
parent
a6b5d6e759
commit
bc62932984
@ -32,9 +32,11 @@ using namespace metal;
|
||||
instantiate_sdpa_vector(type, 64, 64) \
|
||||
instantiate_sdpa_vector(type, 96, 96) \
|
||||
instantiate_sdpa_vector(type, 128, 128) \
|
||||
instantiate_sdpa_vector(type, 256, 256) \
|
||||
instantiate_sdpa_vector_aggregation(type, 64) \
|
||||
instantiate_sdpa_vector_aggregation(type, 96) \
|
||||
instantiate_sdpa_vector_aggregation(type, 128)
|
||||
instantiate_sdpa_vector_aggregation(type, 128) \
|
||||
instantiate_sdpa_vector_aggregation(type, 256)
|
||||
|
||||
instantiate_sdpa_vector_heads(float)
|
||||
instantiate_sdpa_vector_heads(bfloat16_t)
|
||||
|
@ -720,7 +720,8 @@ array scaled_dot_product_attention(
|
||||
|
||||
const bool sdpa_vector_supported_head_dim =
|
||||
query_head_dim == value_head_dim &&
|
||||
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
|
||||
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
|
||||
query_head_dim == 256);
|
||||
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
|
||||
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user