Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-21 00:31:12 +08:00)
Update benchmark and switch off 128 headdim
This commit is contained in:
Parent: 140301aea8
Commit: 791f50d9f3
@ -140,7 +140,7 @@ def get_gflop_count(B, M, N, K):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
||||||
|
|
||||||
dtypes = ("float16", "float32")
|
dtypes = ("float16", "float32")[:1]
|
||||||
transposes = (False,)
|
transposes = (False,)
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -155,8 +155,17 @@ if __name__ == "__main__":
|
|||||||
( 1, 2048, 2048, 64, 32, 32),
|
( 1, 2048, 2048, 64, 32, 32),
|
||||||
( 1, 4096, 4096, 64, 32, 32),
|
( 1, 4096, 4096, 64, 32, 32),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
shapes_80 = (
|
||||||
|
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||||
|
( 1, 1024, 1024, 80, 32, 32),
|
||||||
|
( 1, 2048, 2048, 80, 32, 32),
|
||||||
|
( 1, 4096, 4096, 80, 32, 32),
|
||||||
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
shapes = shapes + shapes_80
|
||||||
|
|
||||||
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
|
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
|
||||||
|
|
||||||
for dtype in dtypes:
|
for dtype in dtypes:
|
||||||
|
@ -644,7 +644,7 @@ array scaled_dot_product_attention(
|
|||||||
const bool sdpa_vector_supported_head_dim =
|
const bool sdpa_vector_supported_head_dim =
|
||||||
query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
|
query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
|
||||||
const bool sdpa_full_supported_head_dim =
|
const bool sdpa_full_supported_head_dim =
|
||||||
query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128;
|
query_head_dim == 64 || query_head_dim == 80;
|
||||||
|
|
||||||
const bool supports_sdpa_full = query_sequence_length >= threshold &&
|
const bool supports_sdpa_full = query_sequence_length >= threshold &&
|
||||||
!mask.has_value() && sdpa_full_supported_head_dim &&
|
!mask.has_value() && sdpa_full_supported_head_dim &&
|
||||||
|
Loading…
Reference in New Issue
Block a user