Ring update (#1885)

This commit is contained in:
Angelos Katharopoulos
2025-02-20 14:32:31 -08:00
committed by GitHub
parent 0ebc8a3d25
commit 10b271d963
2 changed files with 418 additions and 338 deletions

View File

@@ -56,6 +56,45 @@ class TestRingDistributed(mlx_tests.MLXTestCase):
maxrelerror = ((y - z).abs() / z.abs()).max()
self.assertLessEqual(maxrelerror, rtol)
def test_send_recv(self):
world = mx.distributed.init()
dtypes = [
mx.int8,
mx.uint8,
mx.int16,
mx.uint16,
mx.int32,
mx.uint32,
mx.float32,
mx.float16,
mx.bfloat16,
mx.complex64,
]
sizes = [
(7,),
(10,),
(1024,),
(1024, 1024),
]
key = mx.random.key(0)
right = (world.rank() + 1) % world.size()
left = (world.rank() + world.size() - 1) % world.size()
for dt in dtypes:
for sh in sizes:
x = (
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
).astype(dt)
if world.rank() % 2 == 0:
y = mx.distributed.send(x[world.rank()], right)
z = mx.distributed.recv_like(y, left)
mx.eval(y, z)
else:
z = mx.distributed.recv_like(x[world.rank()], left)
y = mx.distributed.send(x[world.rank()], right)
mx.eval(z, y)
self.assertTrue(mx.all(y == x[world.rank()]))
self.assertTrue(mx.all(z == x[left]))
if __name__ == "__main__":
unittest.main()