diff --git a/benchmarks/python/sdpa_bench.py b/benchmarks/python/sdpa_bench.py index 8b3d43694..ab94d99b0 100644 --- a/benchmarks/python/sdpa_bench.py +++ b/benchmarks/python/sdpa_bench.py @@ -145,6 +145,8 @@ if __name__ == "__main__": "float32", ) transposes = (False,) + + # clang-format off shapes = ( # ( B, qsl, ksl, head_dim, n_qh, n_kvh) (1, 32, 32, 64, 32, 32), @@ -156,6 +158,7 @@ if __name__ == "__main__": (1, 2048, 2048, 64, 32, 32), (1, 4096, 4096, 64, 32, 32), ) + # clang-format on print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")