mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	MPI ops in GPU stream for faster comms (#1356)
This commit is contained in:
		
							
								
								
									
										66
									
								
								benchmarks/python/distributed_bench.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								benchmarks/python/distributed_bench.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,66 @@ | ||||
| # Copyright © 2024 Apple Inc. | ||||
|  | ||||
| """ | ||||
| Run with: | ||||
|     mpirun -n 2 python /path/to/distributed_bench.py | ||||
| """ | ||||
|  | ||||
| import time | ||||
|  | ||||
| import mlx.core as mx | ||||
|  | ||||
|  | ||||
| def time_fn(fn, *args, **kwargs): | ||||
|     msg = kwargs.pop("msg", None) | ||||
|     world = mx.distributed.init() | ||||
|     if world.rank() == 0: | ||||
|         if msg: | ||||
|             print(f"Timing {msg} ...", end=" ") | ||||
|         else: | ||||
|             print(f"Timing {fn.__name__} ...", end=" ") | ||||
|  | ||||
|     # warmup | ||||
|     for _ in range(5): | ||||
|         mx.eval(fn(*args, **kwargs)) | ||||
|  | ||||
|     num_iters = 100 | ||||
|     tic = time.perf_counter() | ||||
|     for _ in range(num_iters): | ||||
|         x = mx.eval(fn(*args, **kwargs)) | ||||
|     toc = time.perf_counter() | ||||
|  | ||||
|     msec = 1e3 * (toc - tic) / num_iters | ||||
|     if world.rank() == 0: | ||||
|         print(f"{msec:.5f} msec") | ||||
|  | ||||
|  | ||||
| def time_all_sum(): | ||||
|     shape = (4096,) | ||||
|     x = mx.random.uniform(shape=shape) | ||||
|     mx.eval(x) | ||||
|  | ||||
|     def sine(x): | ||||
|         for _ in range(20): | ||||
|             x = mx.sin(x) | ||||
|         return x | ||||
|  | ||||
|     time_fn(sine, x) | ||||
|  | ||||
|     def all_sum_plain(x): | ||||
|         for _ in range(20): | ||||
|             x = mx.distributed.all_sum(x) | ||||
|         return x | ||||
|  | ||||
|     time_fn(all_sum_plain, x) | ||||
|  | ||||
|     def all_sum_with_sine(x): | ||||
|         for _ in range(20): | ||||
|             x = mx.sin(x) | ||||
|             x = mx.distributed.all_sum(x) | ||||
|         return x | ||||
|  | ||||
|     time_fn(all_sum_with_sine, x) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     time_all_sum() | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun