mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-06 08:24:39 +08:00
Add a ring all gather (#1985)
This commit is contained in:

committed by
GitHub

parent
25814a9458
commit
69e4dd506b
@@ -56,6 +56,24 @@ class TestRingDistributed(mlx_tests.MLXTestCase):
|
||||
maxrelerror = ((y - z).abs() / z.abs()).max()
|
||||
self.assertLessEqual(maxrelerror, rtol)
|
||||
|
||||
def test_all_gather(self):
|
||||
world = mx.distributed.init()
|
||||
dtypes = [
|
||||
mx.int8,
|
||||
mx.uint8,
|
||||
mx.int16,
|
||||
mx.uint16,
|
||||
mx.int32,
|
||||
mx.uint32,
|
||||
mx.float32,
|
||||
mx.complex64,
|
||||
]
|
||||
for dt in dtypes:
|
||||
x = mx.ones((2, 2, 4), dtype=dt)
|
||||
y = mx.distributed.all_gather(x)
|
||||
self.assertEqual(y.shape, (world.size() * 2, 2, 4))
|
||||
self.assertTrue(mx.all(y == 1))
|
||||
|
||||
def test_send_recv(self):
|
||||
world = mx.distributed.init()
|
||||
dtypes = [
|
||||
|
Reference in New Issue
Block a user