mlx/python/tests/nccl_test_distributed.py
Anastasiia Filippova 27778156dc Nccl reduce scatter, all gather (#2727)
* Added reduce scatter and all gather for nccl

* fix unused import, delete unused file

* small fix

* deleted useless condition

* fixed comments

* fix bug in eval_gpu, renamed to sum_scatter, fix docs

* final fix docs

* remove and

* Update mlx/distributed/mpi/mpi.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix broken set input output

* fixes set output

* typo

* fix typo

* no cpu, no gpu for reduce scatter

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-11-05 08:21:11 -08:00
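
For context, here is a minimal sketch (not taken from this commit) of how the two collectives added by this PR are called. The group variable, the use of mx.ones, and the concrete shapes are illustrative assumptions; it presumes a process group already launched across multiple ranks with the NCCL backend:

import mlx.core as mx

group = mx.distributed.init(backend="nccl")
x = mx.ones((8, 4))

# all_gather concatenates every rank's x along axis 0:
# the result has shape (8 * group.size(), 4) on each rank.
gathered = mx.distributed.all_gather(x)

# sum_scatter sums x element-wise across ranks, then scatters the
# result so each rank keeps a slice of shape (8 // group.size(), 4).
scattered = mx.distributed.sum_scatter(x)

mx.eval(gathered, scattered)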

53 lines · 1.4 KiB · Python

# Copyright © 2024 Apple Inc.

import mlx.core as mx
import mlx_distributed_tests
import mlx_tests


class TestNCCLDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
    @classmethod
    def setUpClass(cls):
        _ = mx.distributed.init(strict=True, backend="nccl")
        cls.atol = 1e-4
        cls.rtol = 1e-4

    def test_sum_scatter(self):
        world = mx.distributed.init()
        dtypes = [
            (mx.float32, 1e-6),
            (mx.float16, 5e-3),
            (mx.bfloat16, 1e-1),
        ]
        sizes = [
            (8,),
            (64,),
            (1024,),
            (1024, 1024),
        ]
        key = mx.random.key(world.rank())
        for dt, rtol in dtypes:
            for sh in sizes:
                x = (mx.random.uniform(shape=sh, key=key) * 10).astype(dt)

                # sum_scatter sums across ranks and leaves this rank with its
                # slice along axis 0: shape (sh[0] // world.size(), *sh[1:]).
                y = mx.distributed.sum_scatter(x)

                # Reference: full all_sum (shape sh), then this rank's chunk.
                z = mx.distributed.all_sum(x)
                chunk = sh[0] // world.size()
                start = world.rank() * chunk
                stop = start + chunk
                z_ref = z[start:stop]

                # The max relative error between the scattered result and the
                # reference slice must stay within the per-dtype tolerance.
                maxrelerror = (y - z_ref).abs()
                if rtol > 0:
                    maxrelerror /= z_ref.abs()
                maxrelerror = maxrelerror.max()
                self.assertLessEqual(maxrelerror, rtol)


if __name__ == "__main__":
    mlx_tests.MLXTestRunner()
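
Note: as a distributed test, this file is presumably meant to be started by a multi-process launcher (for example mpirun or MLX's mlx.launch utility; the exact command and flags are not given in this commit) so that mx.distributed.init(strict=True, backend="nccl") can form a multi-rank NCCL group. Run in a single process, the collectives would be trivial and the test would not exercise cross-rank communication.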