mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-24 02:41:19 +08:00)
Update headdim 128 tuning
This commit is contained in:
parent
791f50d9f3
commit
d571366250

@@ -144,7 +144,7 @@ if __name__ == "__main__":
     transposes = (False,)
 
     # fmt: off
-    shapes = (
+    shapes_64 = (
         # (  B,  qsl,  ksl, head_dim, n_qh, n_kvh)
         (    1,   32,   32,       64,   32,   32),
         (    1,   64,   64,       64,   32,   32),
@@ -162,9 +162,16 @@ if __name__ == "__main__":
         (    1, 2048, 2048,       80,   32,   32),
         (    1, 4096, 4096,       80,   32,   32),
     )
+
+    shapes_128 = (
+        # (  B,  qsl,  ksl, head_dim, n_qh, n_kvh)
+        (    1, 1024, 1024,      128,   32,   32),
+        (    1, 2048, 2048,      128,   32,   32),
+        (    1, 4096, 4096,      128,   32,   32),
+    )
     # fmt: on
 
-    shapes = shapes + shapes_80
+    shapes = shapes_64 + shapes_80 + shapes_128
 
     print("  B,  qsl,  ksl, hdim, n_qh, n_kvh, tpose,   dtype, t_unfs, t_fuse, diff%")

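The driver code around these shape tables is not part of this hunk, so below is a minimal, illustrative sketch of how one of the newly added head-dim-128 shapes can be timed fused versus unfused through the public mx.fast API. The helper names and the naive reference are ours, not taken from the benchmark script.

import math
import time

import mlx.core as mx

# One of the shapes added above: B=1, qsl=ksl=4096, head_dim=128, 32 q / 32 kv heads.
B, qsl, ksl, head_dim, n_qh, n_kvh = 1, 4096, 4096, 128, 32, 32
scale = 1.0 / math.sqrt(head_dim)

q = mx.random.normal((B, n_qh, qsl, head_dim), dtype=mx.float16)
k = mx.random.normal((B, n_kvh, ksl, head_dim), dtype=mx.float16)
v = mx.random.normal((B, n_kvh, ksl, head_dim), dtype=mx.float16)
mx.eval(q, k, v)


def attention_unfused(q, k, v, scale):
    # Naive reference: softmax(scale * q @ k^T) @ v
    scores = (q * scale) @ k.transpose(0, 1, 3, 2)
    return mx.softmax(scores, axis=-1) @ v


def attention_fused(q, k, v, scale):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)


def time_ms(fn, reps=20):
    mx.eval(fn(q, k, v, scale))  # warm-up
    tic = time.perf_counter()
    for _ in range(reps):
        mx.eval(fn(q, k, v, scale))
    return 1e3 * (time.perf_counter() - tic) / reps


print(f"t_unfs {time_ms(attention_unfused):.3f} ms, t_fuse {time_ms(attention_fused):.3f} ms")
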
@@ -21,7 +21,7 @@
     uint3 lid [[thread_position_in_threadgroup]]);
 
 #define instantiate_attn_shapes_helper(iname, itype)  \
-  instantiate_attn(iname, itype, 32, 32, 128, 4, 1)   \
+  instantiate_attn(iname, itype, 32, 16, 128, 4, 1)   \
   instantiate_attn(iname, itype, 32, 32, 80, 4, 1)    \
   instantiate_attn(iname, itype, 32, 32, 64, 4, 1)

@@ -29,9 +29,9 @@ void sdpa_full_self_attention_metal(
   int wm = 4;
   int wn = 1;
 
-  int bq = 32;
-  int bk = 32;
+  int bd = q.shape(-1);
+  int bq = 32;
+  int bk = bd < 128 ? 32 : 16;
 
   int B = q.shape(0);
   int H = q.shape(1);
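Read together with the instantiation macro above (whose numeric arguments appear to correspond to bq, bk, bd, wm, wn), the host code now requests a 32x16 query-by-key tile for head dim 128 and keeps 32x32 for the smaller head dims. A tiny Python mirror of that selection, illustrative only; the actual choice lives in the C++ above:

def full_attention_tiles(head_dim: int) -> tuple[int, int, int]:
    # Mirrors the selection in sdpa_full_self_attention_metal:
    # bq stays at 32, bk drops to 16 once head_dim reaches 128.
    bd = head_dim
    bq = 32
    bk = 32 if bd < 128 else 16
    return bq, bk, bd


for hd in (64, 80, 128):
    print(hd, full_attention_tiles(hd))  # (32, 32, 64), (32, 32, 80), (32, 16, 128)
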
@@ -644,7 +644,7 @@ array scaled_dot_product_attention(
   const bool sdpa_vector_supported_head_dim =
       query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
   const bool sdpa_full_supported_head_dim =
-      query_head_dim == 64 || query_head_dim == 80;
+      query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128;
 
   const bool supports_sdpa_full = query_sequence_length >= threshold &&
       !mask.has_value() && sdpa_full_supported_head_dim &&
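With this predicate change, head dim 128 is accepted by both fused paths: the vector path (64/96/128) and the full self-attention path (64/80/128). Below is a small, self-contained sanity check that the public op handles a long, unmasked head-dim-128 call and agrees with an unfused reference; the reference implementation and the tolerance expectation are ours, not part of the commit.

import math

import mlx.core as mx

B, seq, head_dim, n_heads = 1, 2048, 128, 32
scale = 1.0 / math.sqrt(head_dim)
q, k, v = (mx.random.normal((B, n_heads, seq, head_dim), dtype=mx.float16) for _ in range(3))

o_fused = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
o_ref = mx.softmax((q * scale) @ k.transpose(0, 1, 3, 2), axis=-1) @ v
print(mx.abs(o_fused - o_ref).max())  # expect a small, fp16-level difference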