mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	 a4667da1eb
			
		
	
	a4667da1eb
	
	
	
		
			
			* try faster synchronization move event fixes update bench fix fix * non-functioning kernel * try alternative fence * cleanup barrier * get rid of event_fence * update benchmarks * doc string in metal fence
		
			
				
	
	
		
			56 lines
		
	
	
		
			1.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			56 lines
		
	
	
		
			1.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import time

import mlx.core as mx

# Rank reported by the distributed group; the benchmarks below only print
# their result when this is 0.
rank = mx.distributed.init().rank()
| 
 | |
| 
 | |
def timeit(fn, a):
    """Return the mean wall-clock time, in milliseconds, of evaluating ``fn(a)``.

    Args:
        fn: Callable applied to ``a``; its result is passed to ``mx.eval``.
        a: Argument forwarded unchanged to ``fn`` on every call.

    Returns:
        Average time per timed call, in milliseconds.
    """
    # Warmup: a few untimed evaluations before the measurement starts.
    for _ in range(5):
        mx.eval(fn(a))

    its = 10
    tic = time.perf_counter()
    for _ in range(its):
        mx.eval(fn(a))
    toc = time.perf_counter()

    # Convert the total elapsed seconds to milliseconds per call.
    ms = 1000 * (toc - tic) / its
    return ms
| 
 | |
| 
 | |
def all_reduce_benchmark():
    """Time ``mx.distributed.all_sum`` and print the per-call cost on rank 0."""
    a = mx.ones((5, 5), mx.int32)

    # Number of all_sum calls chained inside one timed evaluation; the
    # measured time is divided by this to get a per-call figure.
    its_per_eval = 100

    def fn(x):
        for _ in range(its_per_eval):
            x = mx.distributed.all_sum(x)
            # NOTE(review): the subtraction presumably keeps the values from
            # growing across chained reductions -- confirm intent.
            x = x - 1
        return x

    ms = timeit(fn, a) / its_per_eval
    if rank == 0:
        print(f"All Reduce: time per iteration {ms:.6f} (ms)")
| 
 | |
| 
 | |
def all_gather_benchmark():
    """Time ``mx.distributed.all_gather`` and print the per-call cost on rank 0."""
    a = mx.ones((5, 5), mx.int32)

    # Number of all_gather calls chained inside one timed evaluation; the
    # measured time is divided by this to get a per-call figure.
    its_per_eval = 100

    def fn(x):
        for _ in range(its_per_eval):
            # Element 0 of the gathered result feeds the next iteration, so
            # each call depends on the previous one.
            x = mx.distributed.all_gather(x)[0]
        return x

    ms = timeit(fn, a) / its_per_eval
    if rank == 0:
        print(f"All gather: time per iteration {ms:.6f} (ms)")
| 
 | |
| 
 | |
# Run both collective benchmarks when executed as a script.
if __name__ == "__main__":
    all_reduce_benchmark()
    all_gather_benchmark()