mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
MPI ops in GPU stream for faster comms (#1356)
This commit is contained in:
66
benchmarks/python/distributed_bench.py
Normal file
66
benchmarks/python/distributed_bench.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
"""
|
||||
Run with:
|
||||
mpirun -n 2 python /path/to/distributed_bench.py
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def time_fn(fn, *args, **kwargs):
|
||||
msg = kwargs.pop("msg", None)
|
||||
world = mx.distributed.init()
|
||||
if world.rank() == 0:
|
||||
if msg:
|
||||
print(f"Timing {msg} ...", end=" ")
|
||||
else:
|
||||
print(f"Timing {fn.__name__} ...", end=" ")
|
||||
|
||||
# warmup
|
||||
for _ in range(5):
|
||||
mx.eval(fn(*args, **kwargs))
|
||||
|
||||
num_iters = 100
|
||||
tic = time.perf_counter()
|
||||
for _ in range(num_iters):
|
||||
x = mx.eval(fn(*args, **kwargs))
|
||||
toc = time.perf_counter()
|
||||
|
||||
msec = 1e3 * (toc - tic) / num_iters
|
||||
if world.rank() == 0:
|
||||
print(f"{msec:.5f} msec")
|
||||
|
||||
|
||||
def time_all_sum():
|
||||
shape = (4096,)
|
||||
x = mx.random.uniform(shape=shape)
|
||||
mx.eval(x)
|
||||
|
||||
def sine(x):
|
||||
for _ in range(20):
|
||||
x = mx.sin(x)
|
||||
return x
|
||||
|
||||
time_fn(sine, x)
|
||||
|
||||
def all_sum_plain(x):
|
||||
for _ in range(20):
|
||||
x = mx.distributed.all_sum(x)
|
||||
return x
|
||||
|
||||
time_fn(all_sum_plain, x)
|
||||
|
||||
def all_sum_with_sine(x):
|
||||
for _ in range(20):
|
||||
x = mx.sin(x)
|
||||
x = mx.distributed.all_sum(x)
|
||||
return x
|
||||
|
||||
time_fn(all_sum_with_sine, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_all_sum()
|
Reference in New Issue
Block a user