mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Adds send/recv ops in distributed (#1366)
This commit is contained in:

committed by
GitHub

parent
1d94ac3f90
commit
cdb59faea6
@@ -93,6 +93,23 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertTrue(mx.all(z == z_target))
|
||||
|
||||
def test_send_recv(self):
|
||||
world = mx.distributed.init()
|
||||
pairs = world.split(world.rank() // 2)
|
||||
neighbor = (pairs.rank() + 1) % 2
|
||||
send = pairs.rank() == 0
|
||||
|
||||
x = mx.ones(10)
|
||||
for i in range(10):
|
||||
if send:
|
||||
mx.eval(mx.distributed.send(2 * x, neighbor, group=pairs))
|
||||
else:
|
||||
x = mx.distributed.recv_like(x, neighbor, group=pairs)
|
||||
mx.eval(x)
|
||||
send = not send
|
||||
|
||||
self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user