mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			97 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			97 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023-2024 Apple Inc.
 | |
| 
 | |
| import argparse
 | |
| 
 | |
| import mlx.core as mx
 | |
| import torch
 | |
| from time_utils import measure_runtime
 | |
| 
 | |
| 
 | |
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
    """Time the MLX scatter ``dst[tuple(idx)] = x`` and print the runtime.

    Args:
        dst_shape: Shape of the destination array (float32, random normal).
        x_shape: Shape of the values written into ``dst`` (float32, random
            normal); the caller chooses it to match ``dst[tuple(idx)]``.
        idx_shapes: One shape per index array. The i-th index array indexes
            axis i of ``dst`` and its values are drawn from
            ``[0, dst_shape[i])``.
    """

    def scatter(dst, x, idx):
        dst[tuple(idx)] = x
        # Force evaluation so the scatter itself is part of the timed call.
        mx.eval(dst)

    # randint's upper bound is exclusive: use dst_shape[axis] (not `- 1`) so
    # every position on the indexed axis is reachable, and bound each index
    # array by the axis it actually indexes rather than always axis 0.
    idx = [
        mx.random.randint(0, dst_shape[axis], shape)
        for axis, shape in enumerate(idx_shapes)
    ]
    x = mx.random.normal(x_shape).astype(mx.float32)
    dst = mx.random.normal(dst_shape).astype(mx.float32)

    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
    # Reported in milliseconds per the label; assumes measure_runtime returns ms.
    print(f"MLX: {runtime:.3f}ms")
 | |
| 
 | |
| 
 | |
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
    """Time the PyTorch scatter ``dst[tuple(idx)] = x`` and print the runtime.

    Args:
        dst_shape: Shape of the destination tensor (float32, random normal).
        x_shape: Shape of the values written into ``dst`` (float32, random
            normal); the caller chooses it to match ``dst[tuple(idx)]``.
        idx_shapes: One shape per index tensor. The i-th index tensor indexes
            axis i of ``dst`` and its values are drawn from
            ``[0, dst_shape[i])``.
        device: A ``torch.device`` or device string (e.g. ``"cpu"``,
            ``"mps"``) on which all tensors are placed.
    """
    # Normalize so strings work and device.type can be inspected below.
    device = torch.device(device)

    def scatter(dst, x, idx, device):
        dst[tuple(idx)] = x
        # GPU kernels are launched asynchronously; wait for completion so the
        # timed call covers the scatter itself. Checking .type also matches
        # indexed devices such as "mps:0".
        if device.type == "mps":
            torch.mps.synchronize()
        elif device.type == "cuda":
            torch.cuda.synchronize(device)

    # randint's upper bound is exclusive: use dst_shape[axis] (not `- 1`) so
    # every position on the indexed axis is reachable, and bound each index
    # tensor by the axis it actually indexes rather than always axis 0.
    idx = [
        torch.randint(0, dst_shape[axis], shape).to(device)
        for axis, shape in enumerate(idx_shapes)
    ]
    x = torch.randn(x_shape, dtype=torch.float32).to(device)
    dst = torch.randn(dst_shape, dtype=torch.float32).to(device)

    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
    # Reported in milliseconds per the label; assumes measure_runtime returns ms.
    print(f"PyTorch: {runtime:.3f}ms")
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # The first positional argument of ArgumentParser is `prog`, so pass the
    # text as `description` (and describe the right benchmark).
    parser = argparse.ArgumentParser(description="Scatter benchmarks.")
    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)
        device = torch.device("cpu")
    else:
        if not torch.backends.mps.is_available():
            parser.error("The MPS backend is not available; rerun with --cpu.")
        device = torch.device("mps")

    # Each case is (dst_shape, x_shape, idx_shapes). idx_shapes lists one
    # index-array shape per indexed leading axis of dst, and x_shape equals
    # the shape of dst[tuple(indices)]. Keeping each case in one tuple avoids
    # parallel lists silently drifting out of sync under zip().
    cases = [
        ((10, 64), (1_000_000, 64), [(1_000_000,)]),
        ((100_000, 64), (1_000_000, 64), [(1_000_000,)]),
        ((1_000_000, 64), (100_000, 64), [(100_000,)]),
        ((100_000,), (1_000_000,), [(1_000_000,)]),
        ((200_000,), (20_000_000,), [(20_000_000,)]),
        ((20_000_000,), (20_000_000,), [(20_000_000,)]),
        ((10_000, 64), (1_000_000, 64), [(1_000_000,)]),
        ((100, 64), (10_000_000, 64), [(10_000_000,)]),
        ((100, 10_000, 64), (1_000, 10_000, 64), [(1_000,)]),
        ((10, 100, 100, 21), (10_000, 100, 100, 21), [(10_000,)]),
        ((1_000, 1_000, 10), (1_000, 10), [(1_000,), (1_000,)]),
    ]

    for dst_shape, x_shape, idx_shapes in cases:
        print("=" * 20)
        print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shapes}")
        benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes)
        benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device=device)
 | 
