mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 02:28:13 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			67 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			67 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright © 2024 Apple Inc.
 | 
						|
 | 
						|
"""
 | 
						|
Run with:
 | 
						|
    mpirun -n 2 python /path/to/distributed_bench.py
 | 
						|
"""
 | 
						|
 | 
						|
import time
 | 
						|
 | 
						|
import mlx.core as mx
 | 
						|
 | 
						|
 | 
						|
def time_fn(fn, *args, **kwargs):
    """Time ``fn`` and print its mean per-call wall-clock time on rank 0.

    ``fn(*args, **kwargs)`` is called and its result passed to ``mx.eval``:
    first 5 untimed warmup calls, then 100 timed calls. Every process runs
    the same loop, but only rank 0 prints, so the reported number is the
    time measured on rank 0.

    Args:
        fn: Callable whose return value is passed to ``mx.eval``.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``. The key ``msg`` is
            removed before forwarding and, if truthy, is printed as the
            label instead of ``fn.__name__``.
    """
    msg = kwargs.pop("msg", None)
    world = mx.distributed.init()
    is_root = world.rank() == 0

    if is_root:
        label = msg if msg else fn.__name__
        # flush: the line is left open (end=" ") while the benchmark runs, so
        # push it out now rather than when the result is printed.
        print(f"Timing {label} ...", end=" ", flush=True)

    # warmup: untimed calls before measurement starts
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        # Only the evaluation side effect matters; the return value is unused.
        mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    if is_root:
        print(f"{msec:.5f} msec")
 | 
						|
 | 
						|
 | 
						|
def time_all_sum():
    """Benchmark sin, distributed all_sum, and interleaved sin + all_sum.

    A 4096-element array from ``mx.random.uniform`` is built and evaluated
    once, then three benchmarks are timed with ``time_fn`` in this order:
    20 chained ``mx.sin`` calls, 20 chained ``mx.distributed.all_sum`` calls,
    and 20 iterations of ``mx.sin`` followed by ``mx.distributed.all_sum``.
    The nested function names double as the printed labels.
    """
    data = mx.random.uniform(shape=(4096,))
    mx.eval(data)

    def sine(a):
        for _ in range(20):
            a = mx.sin(a)
        return a

    def all_sum_plain(a):
        for _ in range(20):
            a = mx.distributed.all_sum(a)
        return a

    def all_sum_with_sine(a):
        for _ in range(20):
            a = mx.distributed.all_sum(mx.sin(a))
        return a

    for bench in (sine, all_sum_plain, all_sum_with_sine):
        time_fn(bench, data)
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Script entry point; intended to be launched via mpirun (see module docstring).
    time_all_sum()
 |