mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	[CUDA] Fix reductions (#2314)
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							2c11d10f8d
						
					
				
				
					commit
					772f471ff2
				
			| @@ -5,6 +5,7 @@ import os | ||||
| import time | ||||
|  | ||||
| import torch | ||||
| import torch.cuda | ||||
| import torch.mps | ||||
|  | ||||
|  | ||||
| @@ -44,8 +45,10 @@ def bench(f, *args): | ||||
|  | ||||
|  | ||||
| def sync_if_needed(x): | ||||
|     if x.device != torch.device("cpu"): | ||||
|     if x.device == torch.device("mps"): | ||||
|         torch.mps.synchronize() | ||||
|     elif x.device == torch.device("cuda"): | ||||
|         torch.cuda.synchronize() | ||||
|  | ||||
|  | ||||
| @torch.no_grad() | ||||
| @@ -99,6 +102,14 @@ def reduction(op, axis, x): | ||||
|     sync_if_needed(x) | ||||
|  | ||||
|  | ||||
| @torch.no_grad() | ||||
| def sum_and_add(axis, x, y): | ||||
|     z = x.sum(axis=axis, keepdims=True) | ||||
|     for i in range(50): | ||||
|         z = (z + y).sum(axis=axis, keepdims=True) | ||||
|     sync_if_needed(x) | ||||
|  | ||||
|  | ||||
| @torch.no_grad() | ||||
| def softmax(axis, x): | ||||
|     ys = [] | ||||
| @@ -340,7 +351,11 @@ if __name__ == "__main__": | ||||
|         args.axis.pop(0) | ||||
|  | ||||
|     torch.set_num_threads(1) | ||||
|     device = "cpu" if args.cpu else "mps" | ||||
|     device = "mps" | ||||
|     if torch.cuda.is_available(): | ||||
|         device = "cuda" | ||||
|     if args.cpu: | ||||
|         device = "cpu" | ||||
|  | ||||
|     types = args.dtype | ||||
|     if not types: | ||||
| @@ -460,5 +475,8 @@ if __name__ == "__main__": | ||||
|     elif args.benchmark == "selu": | ||||
|         print(bench(selu, x)) | ||||
|  | ||||
|     elif args.benchmark == "sum_and_add": | ||||
|         print(bench(sum_and_add, axis, *xs)) | ||||
|  | ||||
|     else: | ||||
|         raise ValueError(f"Unknown benchmark `{args.benchmark}`.") | ||||
|   | ||||
		Reference in New Issue
	
	Block a user