Fix benchmark (#1175)

This commit is contained in:
nicolov 2024-06-04 16:50:46 +02:00 committed by GitHub
parent 3de8ce3f3c
commit 81def6ac76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 3 deletions

View File

@ -185,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor:
def mish(x: torch.Tensor) -> torch.Tensor: def mish(x: torch.Tensor) -> torch.Tensor:
y = x y = x
for _ in range(100): for _ in range(100):
return torch.nn.functional.mish(y) y = torch.nn.functional.mish(y)
sync_if_needed(x) sync_if_needed(x)
@ -283,6 +283,14 @@ def topk(axis, x):
sync_if_needed(x) sync_if_needed(x)
@torch.no_grad()
def step_function(x):
y = x
for i in range(100):
y = torch.where(y < 0, 0, 1)
sync_if_needed(x)
@torch.no_grad() @torch.no_grad()
def selu(x): def selu(x):
y = x y = x
@ -446,5 +454,11 @@ if __name__ == "__main__":
elif args.benchmark == "topk": elif args.benchmark == "topk":
print(bench(topk, axis, x)) print(bench(topk, axis, x))
elif args.benchmark == "step":
print(bench(step_function, x))
elif args.benchmark == "selu":
print(bench(selu, x))
else: else:
raise ValueError("Unknown benchmark") raise ValueError(f"Unknown benchmark `{args.benchmark}`.")

View File

@ -16,7 +16,9 @@ def run_or_raise(*args, **kwargs):
result = run(*args, capture_output=True, **kwargs) result = run(*args, capture_output=True, **kwargs)
return float(result.stdout) return float(result.stdout)
except ValueError: except ValueError:
raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}") raise ValueError(
f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
)
def compare(args): def compare(args):