diff --git a/benchmarks/python/blas/bench_gemm.py b/benchmarks/python/blas/bench_gemm.py index 4914c40ba..65dafcbbe 100644 --- a/benchmarks/python/blas/bench_gemm.py +++ b/benchmarks/python/blas/bench_gemm.py @@ -157,7 +157,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"): def get_gflop_count(B, M, N, K): - return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3) + return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1000.0**3) if __name__ == "__main__": @@ -175,6 +175,8 @@ if __name__ == "__main__": (1, 4096, 4096, 4096), ) + print(f" B, M, N, K, dtype, t, gflops_pt, gflops_mx, diff%") + for dtype in dtypes: for transpose in transposes: for B, M, N, K in shapes: @@ -187,7 +189,7 @@ if __name__ == "__main__": diff = gflops_mx / gflops_pt - 1.0 print( - f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%" + f"{B:3d}, {M:4d}, {N:4d}, {K:5d}, {dtype}, {transpose}, {gflops_pt:8.2f}, {gflops_mx:8.2f}, {100. * diff:+5.2f}%" ) if gflops_pt >= 2.0 * gflops_mx: print("ATTENTION ^^^^^^^")