mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	* try faster synchronization move event fixes update bench fix fix * non-functioning kernel * try alternative fence * cleanup barrier * get rid of event_fence * update benchmarks * doc string in metal fence
		
			
				
	
	
		
			56 lines
		
	
	
		
			1.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			56 lines
		
	
	
		
			1.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import time
 | 
						|
 | 
						|
import mlx.core as mx
 | 
						|
 | 
						|
# Call mx.distributed.init() once at import time and keep the rank it reports;
# the benchmark functions below print their result only when rank == 0.
rank = mx.distributed.init().rank()
 | 
						|
 | 
						|
 | 
						|
def timeit(fn, a, *, warmup=5, its=10):
    """Return the average wall-clock time of ``mx.eval(fn(a))`` in milliseconds.

    Args:
        fn: Callable applied to ``a``; its result is passed to ``mx.eval``.
        a: Argument forwarded unchanged to ``fn`` on every call.
        warmup: Number of untimed calls made before measuring (default 5).
        its: Number of timed calls to average over (default 10).

    Returns:
        Mean elapsed milliseconds per timed call, as a float.

    Raises:
        ValueError: If ``warmup`` is negative or ``its`` is less than 1.
    """
    # Validate up front so a bad count fails before any work is issued and
    # the final division can never be by zero.
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")
    if its < 1:
        raise ValueError(f"its must be >= 1, got {its}")

    # warmup
    for _ in range(warmup):
        mx.eval(fn(a))

    tic = time.perf_counter()
    for _ in range(its):
        mx.eval(fn(a))
    toc = time.perf_counter()

    # perf_counter() reports seconds; scale to milliseconds per call.
    return 1000 * (toc - tic) / its
 | 
						|
 | 
						|
 | 
						|
def all_reduce_benchmark():
    """Time chained ``mx.distributed.all_sum`` calls and report from rank 0.

    One evaluation runs a chain of 100 ``all_sum`` calls on a 5x5 int32 array
    of ones, subtracting 1 after each call. The printed figure is the
    ``timeit`` result divided by the chain length.
    """
    chain_length = 100
    ones = mx.ones((5, 5), mx.int32)

    def chained_all_sum(x):
        # Each call consumes the previous result, so the calls form one
        # dependent chain inside a single evaluation.
        for _ in range(chain_length):
            x = mx.distributed.all_sum(x) - 1
        return x

    per_call_ms = timeit(chained_all_sum, ones) / chain_length

    # Only rank 0 reports.
    if rank != 0:
        return
    print(f"All Reduce: time per iteration {per_call_ms:.6f} (ms)")
 | 
						|
 | 
						|
 | 
						|
def all_gather_benchmark():
    """Time chained ``mx.distributed.all_gather`` calls and report from rank 0.

    One evaluation runs a chain of 100 ``all_gather`` calls, starting from a
    5x5 int32 array of ones and feeding element 0 of each gathered result
    into the next call. The printed figure is the ``timeit`` result divided
    by the chain length.
    """
    chain_length = 100
    ones = mx.ones((5, 5), mx.int32)

    def chained_all_gather(x):
        for _ in range(chain_length):
            gathered = mx.distributed.all_gather(x)
            # NOTE(review): indexing with [0] means the operand passed to the
            # next call is not the full gathered result; presumably its shape
            # differs from the initial 5x5 input -- confirm this is intended.
            x = gathered[0]
        return x

    per_call_ms = timeit(chained_all_gather, ones) / chain_length

    # Only rank 0 reports.
    if rank != 0:
        return
    print(f"All gather: time per iteration {per_call_ms:.6f} (ms)")
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Run the benchmarks in order: all-reduce first, then all-gather.
    for benchmark in (all_reduce_benchmark, all_gather_benchmark):
        benchmark()
 |