mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix benchmark (#1175)
This commit is contained in:
		| @@ -185,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor: | ||||
| def mish(x: torch.Tensor) -> torch.Tensor: | ||||
|     y = x | ||||
|     for _ in range(100): | ||||
|         return torch.nn.functional.mish(y) | ||||
|         y = torch.nn.functional.mish(y) | ||||
|     sync_if_needed(x) | ||||
|  | ||||
|  | ||||
| @@ -283,6 +283,14 @@ def topk(axis, x): | ||||
|     sync_if_needed(x) | ||||
|  | ||||
|  | ||||
| @torch.no_grad() | ||||
| def step_function(x): | ||||
|     y = x | ||||
|     for i in range(100): | ||||
|         y = torch.where(y < 0, 0, 1) | ||||
|     sync_if_needed(x) | ||||
|  | ||||
|  | ||||
| @torch.no_grad() | ||||
| def selu(x): | ||||
|     y = x | ||||
| @@ -446,5 +454,11 @@ if __name__ == "__main__": | ||||
|     elif args.benchmark == "topk": | ||||
|         print(bench(topk, axis, x)) | ||||
|  | ||||
|     elif args.benchmark == "step": | ||||
|         print(bench(step_function, x)) | ||||
|  | ||||
|     elif args.benchmark == "selu": | ||||
|         print(bench(selu, x)) | ||||
|  | ||||
|     else: | ||||
|         raise ValueError("Unknown benchmark") | ||||
|         raise ValueError(f"Unknown benchmark `{args.benchmark}`.") | ||||
|   | ||||
| @@ -16,7 +16,9 @@ def run_or_raise(*args, **kwargs): | ||||
|         result = run(*args, capture_output=True, **kwargs) | ||||
|         return float(result.stdout) | ||||
|     except ValueError: | ||||
|         raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}") | ||||
|         raise ValueError( | ||||
|             f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}" | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def compare(args): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 nicolov
					nicolov