Update sdpa_benchmarks

This commit is contained in:
Jagrit Digani 2024-11-20 15:20:34 -08:00
parent c9ab537b9a
commit 0c22440c75

View File

@ -140,25 +140,22 @@ def get_gflop_count(B, M, N, K):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks") parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ( dtypes = ("float16", "float32")
"float16",
"float32",
)
transposes = (False,) transposes = (False,)
# clang-format off # fmt: off
shapes = ( shapes = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh) # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
(1, 32, 32, 64, 32, 32), ( 1, 32, 32, 64, 32, 32),
(1, 64, 64, 64, 32, 32), ( 1, 64, 64, 64, 32, 32),
(1, 128, 128, 64, 32, 32), ( 1, 128, 128, 64, 32, 32),
(1, 256, 256, 64, 32, 32), ( 1, 256, 256, 64, 32, 32),
(1, 512, 512, 64, 32, 32), ( 1, 512, 512, 64, 32, 32),
(1, 1024, 1024, 64, 32, 32), ( 1, 1024, 1024, 64, 32, 32),
(1, 2048, 2048, 64, 32, 32), ( 1, 2048, 2048, 64, 32, 32),
(1, 4096, 4096, 64, 32, 32), ( 1, 4096, 4096, 64, 32, 32),
) )
# clang-format on # fmt: on
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%") print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")