Files
mlx/python/tests/ring_test_distributed.py
Anastasiia Filippova 27778156dc Nccl reduce scatter, all gather (#2727)
* Added reduce scatter and all gather for nccl

* fix unused import, delete unused file

* small fix

* deleted useless condition

* fixed comments

* fix bug in eval_gpu, renamed to sum_scatter, fix docs

* final fix docs

* remove and

* Update mlx/distributed/mpi/mpi.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix broken set input output

* fixes set output

* typo

* fix typo

* no cpu, no gpu for reduce scatter

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-11-05 08:21:11 -08:00

132 lines
3.9 KiB
Python

# Copyright © 2024 Apple Inc.
import mlx.core as mx
import mlx_distributed_tests
import mlx_tests
class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
    """Distributed-op tests run against the ``ring`` backend.

    Assumes it is launched across 8 processes (see ``test_groups``);
    each process executes the same body and asserts on its own shard.
    """

    @classmethod
    def setUpClass(cls):
        # Fail fast (strict=True) if the ring backend is unavailable.
        _ = mx.distributed.init(strict=True, backend="ring")
        cls.atol = 1e-6
        cls.rtol = 1e-4

    def test_groups(self):
        """Init is idempotent and the ring backend cannot be split."""
        group = mx.distributed.init()
        self.assertEqual(group.size(), 8)
        self.assertTrue(0 <= group.rank() < 8)

        # A second init must hand back the same world view.
        again = mx.distributed.init()
        self.assertEqual(group.size(), again.size())
        self.assertEqual(group.rank(), again.rank())

        # Splitting is unsupported on the ring backend.
        with self.assertRaises(RuntimeError):
            sub = group.split(group.rank() % 2)

    def test_all_reduce_extra(self):
        """all_sum / all_max / all_min over dtypes the common suite skips."""
        group = mx.distributed.init()
        cases = (
            (mx.int16, 0),
            (mx.uint16, 0),
            (mx.complex64, 1e-6),
        )
        shapes = ((7,), (10,), (1024,), (1024, 1024))
        key = mx.random.key(0)
        for dtype, tol in cases:
            for shape in shapes:
                # Same key everywhere: every rank sees the identical tensor
                # and can compute the expected reduction locally.
                data = (
                    mx.random.uniform(shape=(group.size(),) + shape, key=key) * 10
                ).astype(dtype)
                mine = data[group.rank()]

                # All sum — exact for the integer dtypes (tol == 0),
                # relative error otherwise.
                summed = mx.distributed.all_sum(mine)
                expected = data.sum(0)
                err = (summed - expected).abs()
                if tol > 0:
                    err = err / expected.abs()
                self.assertLessEqual(err.max(), tol)

                # All max
                self.assertTrue(
                    mx.all(mx.distributed.all_max(mine) == data.max(0))
                )

                # All min
                self.assertTrue(
                    mx.all(mx.distributed.all_min(mine) == data.min(0))
                )

    def test_all_gather_extra(self):
        """all_gather concatenates along axis 0 for the extra dtypes."""
        group = mx.distributed.init()
        for dtype in (mx.int16, mx.uint16, mx.complex64):
            shard = mx.ones((2, 2, 4), dtype=dtype)
            gathered = mx.distributed.all_gather(shard)
            self.assertEqual(gathered.shape, (group.size() * 2, 2, 4))
            self.assertTrue(mx.all(gathered == 1))

    def test_send_recv(self):
        """Ring exchange: send right, receive from the left neighbour."""
        group = mx.distributed.init()
        dtypes = (
            mx.int8,
            mx.uint8,
            mx.int16,
            mx.uint16,
            mx.int32,
            mx.uint32,
            mx.float32,
            mx.float16,
            mx.bfloat16,
            mx.complex64,
        )
        shapes = ((7,), (10,), (1024,), (1024, 1024))
        key = mx.random.key(0)
        size = group.size()
        rank = group.rank()
        nxt = (rank + 1) % size
        prv = (rank + size - 1) % size
        for dtype in dtypes:
            for shape in shapes:
                # Shared key => all ranks hold the same full tensor.
                data = (
                    mx.random.uniform(shape=(size,) + shape, key=key) * 10
                ).astype(dtype)
                mine = data[rank]
                # Alternate the send/recv order by parity so neighbouring
                # ranks pair up instead of deadlocking.
                if rank % 2 == 0:
                    sent = mx.distributed.send(mine, nxt)
                    got = mx.distributed.recv_like(sent, prv)
                    mx.eval(sent, got)
                else:
                    got = mx.distributed.recv_like(mine, prv)
                    sent = mx.distributed.send(mine, nxt)
                    mx.eval(got, sent)
                # send() passes its input through unchanged; the received
                # slice must match the left neighbour's shard.
                self.assertTrue(mx.all(sent == mine))
                self.assertTrue(mx.all(got == data[prv]))

    def test_all_gather_vjp(self):
        """Gradient of all_gather routes the cotangent back to rank 0 only."""

        def fun(x):
            return mx.distributed.all_gather(x)[0]

        grad_at_one = mx.grad(fun)(mx.array(1.0))
        expected = 1.0 if mx.distributed.init().rank() == 0 else 0.0
        self.assertEqual(grad_at_one.item(), expected)
# Entry point: run under the project's distributed test runner.
if __name__ == "__main__":
    mlx_tests.MLXTestRunner()