Adds send/recv ops in distributed (#1366)

This commit is contained in:
Angelos Katharopoulos
2024-08-26 23:01:37 -07:00
committed by GitHub
parent 1d94ac3f90
commit cdb59faea6
13 changed files with 345 additions and 19 deletions

View File

@@ -93,6 +93,23 @@ class TestDistributed(mlx_tests.MLXTestCase):
self.assertTrue(mx.all(z == z_target))
def test_send_recv(self):
world = mx.distributed.init()
pairs = world.split(world.rank() // 2)
neighbor = (pairs.rank() + 1) % 2
send = pairs.rank() == 0
x = mx.ones(10)
for i in range(10):
if send:
mx.eval(mx.distributed.send(2 * x, neighbor, group=pairs))
else:
x = mx.distributed.recv_like(x, neighbor, group=pairs)
mx.eval(x)
send = not send
self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
if __name__ == "__main__":
unittest.main()