mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			67 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			67 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2024 Apple Inc.
 | |
| 
 | |
| """
 | |
| Run with:
 | |
|     mpirun -n 2 python /path/to/distributed_bench.py
 | |
| """
 | |
| 
 | |
| import time
 | |
| 
 | |
| import mlx.core as mx
 | |
| 
 | |
| 
 | |
def time_fn(fn, *args, **kwargs):
    """Benchmark ``fn(*args, **kwargs)`` and print the mean time per call.

    ``fn`` is called 5 times untimed as a warmup, then 100 times inside a
    ``time.perf_counter`` window. Every call's result is passed to
    ``mx.eval`` so that the work is actually executed within the timed
    region. Only rank 0 of the group returned by ``mx.distributed.init()``
    prints, so the report is not repeated by every process.

    Args:
        fn: Callable whose return value is passed to ``mx.eval``.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``. The key ``msg`` is
            removed before forwarding; if truthy it is used as the label in
            the report instead of ``fn.__name__``.
    """
    msg = kwargs.pop("msg", None)
    world = mx.distributed.init()
    is_root = world.rank() == 0

    if is_root:
        label = msg if msg else fn.__name__
        print(f"Timing {label} ...", end=" ")

    # Untimed warmup calls so one-time startup costs are excluded from the
    # measurement below.
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        # mx.eval's return value is not needed; it is called only to force
        # evaluation of fn's output.
        mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    if is_root:
        print(f"{msec:.5f} msec")
 | |
| 
 | |
| 
 | |
def time_all_sum():
    """Time three 20-step chains on one random array of shape ``(4096,)``.

    The chains are: ``mx.sin`` only, ``mx.distributed.all_sum`` only, and
    ``mx.sin`` followed by ``all_sum`` at each step. Each chain is timed
    with ``time_fn``, which labels it by the nested function's name.
    """
    depth = 20

    data = mx.random.uniform(shape=(4096,))
    mx.eval(data)

    def sine(y):
        for _step in range(depth):
            y = mx.sin(y)
        return y

    def all_sum_plain(y):
        for _step in range(depth):
            y = mx.distributed.all_sum(y)
        return y

    def all_sum_with_sine(y):
        for _step in range(depth):
            y = mx.distributed.all_sum(mx.sin(y))
        return y

    # Same order as before: compute-only, communication-only, then mixed.
    for bench in (sine, all_sum_plain, all_sum_with_sine):
        time_fn(bench, data)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Script entry point: run the all_sum benchmarks. Per the module
    # docstring, this is meant to be launched under mpirun (e.g. -n 2).
    time_all_sum()
 | 
