mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Fix benchmark (#1175)
This commit is contained in:
parent
3de8ce3f3c
commit
81def6ac76
@ -185,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor:
|
|||||||
def mish(x: torch.Tensor) -> torch.Tensor:
|
def mish(x: torch.Tensor) -> torch.Tensor:
|
||||||
y = x
|
y = x
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
return torch.nn.functional.mish(y)
|
y = torch.nn.functional.mish(y)
|
||||||
sync_if_needed(x)
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
@ -283,6 +283,14 @@ def topk(axis, x):
|
|||||||
sync_if_needed(x)
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step_function(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = torch.where(y < 0, 0, 1)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def selu(x):
|
def selu(x):
|
||||||
y = x
|
y = x
|
||||||
@ -446,5 +454,11 @@ if __name__ == "__main__":
|
|||||||
elif args.benchmark == "topk":
|
elif args.benchmark == "topk":
|
||||||
print(bench(topk, axis, x))
|
print(bench(topk, axis, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "step":
|
||||||
|
print(bench(step_function, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "selu":
|
||||||
|
print(bench(selu, x))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown benchmark")
|
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
||||||
|
@ -16,7 +16,9 @@ def run_or_raise(*args, **kwargs):
|
|||||||
result = run(*args, capture_output=True, **kwargs)
|
result = run(*args, capture_output=True, **kwargs)
|
||||||
return float(result.stdout)
|
return float(result.stdout)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
|
raise ValueError(
|
||||||
|
f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def compare(args):
|
def compare(args):
|
||||||
|
Loading…
Reference in New Issue
Block a user