From 0c22440c75709fecc2a3be0027e65aff21b4400b Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 20 Nov 2024 15:20:34 -0800 Subject: [PATCH] Update sdpa_benchmarks --- benchmarks/python/sdpa_bench.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/benchmarks/python/sdpa_bench.py b/benchmarks/python/sdpa_bench.py index ab94d99b0..56042e2d2 100644 --- a/benchmarks/python/sdpa_bench.py +++ b/benchmarks/python/sdpa_bench.py @@ -140,25 +140,22 @@ def get_gflop_count(B, M, N, K): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run gemm benchmarks") - dtypes = ( - "float16", - "float32", - ) + dtypes = ("float16", "float32") transposes = (False,) - # clang-format off + # fmt: off shapes = ( # ( B, qsl, ksl, head_dim, n_qh, n_kvh) - (1, 32, 32, 64, 32, 32), - (1, 64, 64, 64, 32, 32), - (1, 128, 128, 64, 32, 32), - (1, 256, 256, 64, 32, 32), - (1, 512, 512, 64, 32, 32), - (1, 1024, 1024, 64, 32, 32), - (1, 2048, 2048, 64, 32, 32), - (1, 4096, 4096, 64, 32, 32), + ( 1, 32, 32, 64, 32, 32), + ( 1, 64, 64, 64, 32, 32), + ( 1, 128, 128, 64, 32, 32), + ( 1, 256, 256, 64, 32, 32), + ( 1, 512, 512, 64, 32, 32), + ( 1, 1024, 1024, 64, 32, 32), + ( 1, 2048, 2048, 64, 32, 32), + ( 1, 4096, 4096, 64, 32, 32), ) - # clang-format on + # fmt: on print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")