Update benchmark and switch off 128 headdim

This commit is contained in:
Jagrit Digani 2024-11-20 15:37:36 -08:00
parent 140301aea8
commit 791f50d9f3
2 changed files with 19 additions and 10 deletions

View File

@ -140,23 +140,32 @@ def get_gflop_count(B, M, N, K):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ("float16", "float32")
dtypes = ("float16", "float32")[:1]
transposes = (False,)
# fmt: off
shapes = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 32, 32, 64, 32, 32),
( 1, 64, 64, 64, 32, 32),
( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 32),
( 1, 2048, 2048, 64, 32, 32),
( 1, 4096, 4096, 64, 32, 32),
( 1, 32, 32, 64, 32, 32),
( 1, 64, 64, 64, 32, 32),
( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 32),
( 1, 2048, 2048, 64, 32, 32),
( 1, 4096, 4096, 64, 32, 32),
)
shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 32),
( 1, 2048, 2048, 80, 32, 32),
( 1, 4096, 4096, 80, 32, 32),
)
# fmt: on
shapes = shapes + shapes_80
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
for dtype in dtypes:

View File

@ -644,7 +644,7 @@ array scaled_dot_product_attention(
const bool sdpa_vector_supported_head_dim =
query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
const bool sdpa_full_supported_head_dim =
query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128;
query_head_dim == 64 || query_head_dim == 80;
const bool supports_sdpa_full = query_sequence_length >= threshold &&
!mask.has_value() && sdpa_full_supported_head_dim &&