"""Micro-benchmarks for MLX distributed collectives (all_sum and all_gather).

Every process executes the same chain of collectives; only rank 0 prints the
measured time per collective call.
"""

import time

import mlx.core as mx

# Initialize the distributed group once; both benchmarks run on it.
group = mx.distributed.init()
rank = group.rank()


def timeit(fn, a, *, warmup=5, its=10):
    """Return the mean wall-clock time, in milliseconds, of ``mx.eval(fn(a))``.

    Args:
        fn: Callable building an MLX computation from ``a``.
        a: Input array passed to ``fn``.
        warmup: Number of untimed calls run first to exclude first-call overhead.
        its: Number of timed calls averaged over.

    Returns:
        Mean milliseconds per ``mx.eval(fn(a))`` call.
    """
    # mx.eval forces the lazily built graph to actually execute.
    for _ in range(warmup):
        mx.eval(fn(a))

    tic = time.perf_counter()
    for _ in range(its):
        mx.eval(fn(a))
    toc = time.perf_counter()
    return 1000 * (toc - tic) / its


def all_reduce_benchmark():
    """Time a serial chain of ``mx.distributed.all_sum`` calls on a 5x5 int32 array.

    Prints the mean time per all_sum call (ms) on rank 0.
    """
    a = mx.ones((5, 5), mx.int32)
    its_per_eval = 100
    # Summing an all-ones array over the group yields group.size() in every
    # entry; subtracting size - 1 restores ones, so values stay bounded for
    # any number of processes (a constant ``- 1`` only does that for exactly
    # two processes and overflows int32 for more).
    offset = group.size() - 1

    def fn(x):
        # Each call consumes the previous result, so the collectives form a
        # dependent chain inside a single evaluated graph.
        for _ in range(its_per_eval):
            x = mx.distributed.all_sum(x) - offset
        return x

    ms = timeit(fn, a) / its_per_eval
    if rank == 0:
        print(f"All Reduce: time per iteration {ms:.6f} (ms)")


def all_gather_benchmark():
    """Time a serial chain of ``mx.distributed.all_gather`` calls on a 5x5 int32 array.

    Prints the mean time per all_gather call (ms) on rank 0.
    """
    a = mx.ones((5, 5), mx.int32)
    its_per_eval = 100

    def fn(x):
        rows = x.shape[0]
        for _ in range(its_per_eval):
            # all_gather concatenates along axis 0, producing
            # (group.size() * rows, ...). Keeping the first ``rows`` rows gives
            # every iteration an operand of the original shape. Indexing with
            # ``[0]`` would shrink it to 1-D and then 0-d, and the benchmark
            # would then time mostly scalar gathers.
            x = mx.distributed.all_gather(x)[:rows]
        return x

    ms = timeit(fn, a) / its_per_eval
    if rank == 0:
        print(f"All gather: time per iteration {ms:.6f} (ms)")


if __name__ == "__main__":
    all_reduce_benchmark()
    all_gather_benchmark()