mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Ring update (#1885)
This commit is contained in:

committed by
GitHub

parent
0ebc8a3d25
commit
10b271d963
@@ -56,6 +56,45 @@ class TestRingDistributed(mlx_tests.MLXTestCase):
|
||||
maxrelerror = ((y - z).abs() / z.abs()).max()
|
||||
self.assertLessEqual(maxrelerror, rtol)
|
||||
|
||||
def test_send_recv(self):
|
||||
world = mx.distributed.init()
|
||||
dtypes = [
|
||||
mx.int8,
|
||||
mx.uint8,
|
||||
mx.int16,
|
||||
mx.uint16,
|
||||
mx.int32,
|
||||
mx.uint32,
|
||||
mx.float32,
|
||||
mx.float16,
|
||||
mx.bfloat16,
|
||||
mx.complex64,
|
||||
]
|
||||
sizes = [
|
||||
(7,),
|
||||
(10,),
|
||||
(1024,),
|
||||
(1024, 1024),
|
||||
]
|
||||
key = mx.random.key(0)
|
||||
right = (world.rank() + 1) % world.size()
|
||||
left = (world.rank() + world.size() - 1) % world.size()
|
||||
for dt in dtypes:
|
||||
for sh in sizes:
|
||||
x = (
|
||||
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
|
||||
).astype(dt)
|
||||
if world.rank() % 2 == 0:
|
||||
y = mx.distributed.send(x[world.rank()], right)
|
||||
z = mx.distributed.recv_like(y, left)
|
||||
mx.eval(y, z)
|
||||
else:
|
||||
z = mx.distributed.recv_like(x[world.rank()], left)
|
||||
y = mx.distributed.send(x[world.rank()], right)
|
||||
mx.eval(z, y)
|
||||
self.assertTrue(mx.all(y == x[world.rank()]))
|
||||
self.assertTrue(mx.all(z == x[left]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user