diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py index a83e1b503..a2157707b 100644 --- a/benchmarks/python/comparative/bench_torch.py +++ b/benchmarks/python/comparative/bench_torch.py @@ -185,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor: def mish(x: torch.Tensor) -> torch.Tensor: y = x for _ in range(100): - return torch.nn.functional.mish(y) + y = torch.nn.functional.mish(y) sync_if_needed(x) @@ -283,6 +283,14 @@ def topk(axis, x): sync_if_needed(x) +@torch.no_grad() +def step_function(x): + y = x + for i in range(100): + y = torch.where(y < 0, 0, 1) + sync_if_needed(x) + + @torch.no_grad() def selu(x): y = x @@ -446,5 +454,11 @@ if __name__ == "__main__": elif args.benchmark == "topk": print(bench(topk, axis, x)) + elif args.benchmark == "step": + print(bench(step_function, x)) + + elif args.benchmark == "selu": + print(bench(selu, x)) + else: - raise ValueError("Unknown benchmark") + raise ValueError(f"Unknown benchmark `{args.benchmark}`.") diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py index 5b71cf583..68b4a5bd3 100644 --- a/benchmarks/python/comparative/compare.py +++ b/benchmarks/python/comparative/compare.py @@ -16,7 +16,9 @@ def run_or_raise(*args, **kwargs): result = run(*args, capture_output=True, **kwargs) return float(result.stdout) except ValueError: - raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}") + raise ValueError( + f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}" + ) def compare(args):