mpi send use input as output (#1750)

* mpi send use input as output

* move earlier
This commit is contained in:
Awni Hannun
2025-01-06 06:08:43 -08:00
committed by GitHub
parent eab93985b8
commit 058d6ce683
5 changed files with 19 additions and 3 deletions

View File

@@ -148,7 +148,7 @@ void init_distributed(nb::module_& parent_module) {
in which case the default stream of the default device is used.
Returns:
array: An empty array which when evaluated the send is performed.
array: An array identical to ``x`` which when evaluated the send is performed.
)pbdoc");
m.def(

View File

@@ -111,6 +111,14 @@ class TestDistributed(mlx_tests.MLXTestCase):
self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
# Check recv and computation in same eval:
y = mx.ones((5, 5)) + mx.array(2.0)
if send:
x = mx.distributed.send(2 * x, neighbor, group=pairs)
else:
x = mx.distributed.recv_like(x, neighbor, group=pairs)
mx.eval(y, x)
def test_average_gradients(self):
original_all_sum = mx.distributed.all_sum
n_calls = 0