mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 09:07:12 +08:00
* Added reduce scatter and all gather for nccl * fix unused import, delete unused file * small fix * deleted useless condition * fixed comments * fix bug in eval_gpu, renamed to sum_scatter, fix docs * final fix docs * remove and * Update mlx/distributed/mpi/mpi.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * fix broken set input output * fixes set output * typo * fix typo * no cpu, no gpu for reduce scatter --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
53 lines
1.4 KiB
Python
53 lines
1.4 KiB
Python
# Copyright © 2024 Apple Inc.
|
|
|
|
import mlx.core as mx
|
|
import mlx_distributed_tests
|
|
import mlx_tests
|
|
|
|
|
|
class TestNCCLDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
    """Distributed-op tests run against the NCCL backend."""

    @classmethod
    def setUpClass(cls):
        # strict=True: fail fast if the NCCL backend cannot be initialized.
        _ = mx.distributed.init(strict=True, backend="nccl")
        # Tolerances used by the shared test cases in the base class.
        cls.atol = 1e-4
        cls.rtol = 1e-4

    def test_sum_scatter(self):
        """sum_scatter(x) must equal this rank's slice of all_sum(x)."""
        world = mx.distributed.init()
        size = world.size()
        rank = world.rank()

        # (dtype, relative tolerance) pairs — looser bounds for low precision.
        cases = [
            (mx.float32, 1e-6),
            (mx.float16, 5e-3),
            (mx.bfloat16, 1e-1),
        ]
        shapes = [(8,), (64,), (1024,), (1024, 1024)]

        # Per-rank key so every rank contributes distinct data.
        key = mx.random.key(rank)

        for dtype, tol in cases:
            for shape in shapes:
                x = (mx.random.uniform(shape=shape, key=key) * 10).astype(dtype)

                # Op under test: result has shape[0] // size leading entries.
                scattered = mx.distributed.sum_scatter(x)

                # Reference: full all_sum, then take this rank's chunk.
                full = mx.distributed.all_sum(x)
                chunk = shape[0] // size
                expected = full[rank * chunk : rank * chunk + chunk]

                err = (scattered - expected).abs()
                if tol > 0:
                    err /= expected.abs()
                self.assertLessEqual(err.max(), tol)
|
if __name__ == "__main__":
    # Entry point: run this file's tests through the project's test runner.
    mlx_tests.MLXTestRunner()