Add a ring all gather (#1985)

This commit is contained in:
Angelos Katharopoulos
2025-03-21 13:36:51 -07:00
committed by GitHub
parent 25814a9458
commit 69e4dd506b
2 changed files with 91 additions and 6 deletions

View File

@@ -56,6 +56,24 @@ class TestRingDistributed(mlx_tests.MLXTestCase):
maxrelerror = ((y - z).abs() / z.abs()).max()
self.assertLessEqual(maxrelerror, rtol)
def test_all_gather(self):
world = mx.distributed.init()
dtypes = [
mx.int8,
mx.uint8,
mx.int16,
mx.uint16,
mx.int32,
mx.uint32,
mx.float32,
mx.complex64,
]
for dt in dtypes:
x = mx.ones((2, 2, 4), dtype=dt)
y = mx.distributed.all_gather(x)
self.assertEqual(y.shape, (world.size() * 2, 2, 4))
self.assertTrue(mx.all(y == 1))
def test_send_recv(self):
world = mx.distributed.init()
dtypes = [