Nccl reduce scatter, all gather (#2727)

* Added reduce scatter and all gather for nccl

* fix unused import, delete unused file

* small fix

* deleted useless condition

* fixed comments

* fix bug in eval_gpu, renamed to sum_scatter, fix docs

* final fix docs

* remove and

* Update mlx/distributed/mpi/mpi.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix broken set input output

* fixes set output

* typo

* fix typo

* no cpu, no gpu for reduce scatter

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Anastasiia Filippova
2025-11-05 17:21:11 +01:00
committed by GitHub
parent 761f901a41
commit 27778156dc
19 changed files with 351 additions and 311 deletions

View File

@@ -1,7 +1,5 @@
# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_distributed_tests
import mlx_tests
@@ -10,7 +8,9 @@ import mlx_tests
class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
@classmethod
def setUpClass(cls):
world = mx.distributed.init(strict=True, backend="ring")
_ = mx.distributed.init(strict=True, backend="ring")
cls.atol = 1e-6
cls.rtol = 1e-4
def test_groups(self):
world = mx.distributed.init()
@@ -24,18 +24,11 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
with self.assertRaises(RuntimeError):
sub = world.split(world.rank() % 2)
def test_all_reduce(self):
def test_all_reduce_extra(self):
world = mx.distributed.init()
dtypes = [
(mx.int8, 0),
(mx.uint8, 0),
(mx.int16, 0),
(mx.uint16, 0),
(mx.int32, 0),
(mx.uint32, 0),
(mx.float32, 1e-6),
(mx.float16, 5e-3),
(mx.bfloat16, 1e-1),
(mx.complex64, 1e-6),
]
sizes = [
@@ -45,7 +38,6 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
(1024, 1024),
]
key = mx.random.key(0)
reductions = ["min", "max", "sum"]
for dt, rtol in dtypes:
for sh in sizes:
@@ -72,16 +64,11 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
z = x.min(0)
self.assertTrue(mx.all(y == z))
def test_all_gather(self):
def test_all_gather_extra(self):
world = mx.distributed.init()
dtypes = [
mx.int8,
mx.uint8,
mx.int16,
mx.uint16,
mx.int32,
mx.uint32,
mx.float32,
mx.complex64,
]
for dt in dtypes: