diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py index a2157707b..dd3436d9a 100644 --- a/benchmarks/python/comparative/bench_torch.py +++ b/benchmarks/python/comparative/bench_torch.py @@ -5,6 +5,7 @@ import os import time import torch +import torch.cuda import torch.mps @@ -44,8 +45,10 @@ def bench(f, *args): def sync_if_needed(x): - if x.device != torch.device("cpu"): + if x.device == torch.device("mps"): torch.mps.synchronize() + elif x.device == torch.device("cuda"): + torch.cuda.synchronize() @torch.no_grad() @@ -99,6 +102,14 @@ def reduction(op, axis, x): sync_if_needed(x) +@torch.no_grad() +def sum_and_add(axis, x, y): + z = x.sum(axis=axis, keepdims=True) + for i in range(50): + z = (z + y).sum(axis=axis, keepdims=True) + sync_if_needed(x) + + @torch.no_grad() def softmax(axis, x): ys = [] @@ -340,7 +351,11 @@ if __name__ == "__main__": args.axis.pop(0) torch.set_num_threads(1) - device = "cpu" if args.cpu else "mps" + device = "mps" + if torch.cuda.is_available(): + device = "cuda" + if args.cpu: + device = "cpu" types = args.dtype if not types: @@ -460,5 +475,8 @@ if __name__ == "__main__": elif args.benchmark == "selu": print(bench(selu, x)) + elif args.benchmark == "sum_and_add": + print(bench(sum_and_add, axis, *xs)) + else: raise ValueError(f"Unknown benchmark `{args.benchmark}`.")