From 058d6ce6830c5a3b48b315ba66c192211e989779 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 6 Jan 2025 06:08:43 -0800
Subject: [PATCH] mpi send use input as output (#1750)

* mpi send use input as output

* move earlier
---
 mlx/backend/metal/distributed.cpp    | 5 ++++-
 mlx/distributed/ops.cpp              | 5 ++++-
 mlx/distributed/primitives.cpp       | 2 ++
 python/src/distributed.cpp           | 2 +-
 python/tests/mpi_test_distributed.py | 8 ++++++++
 5 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp
index 4cbd56af6..98d484c10 100644
--- a/mlx/backend/metal/distributed.cpp
+++ b/mlx/backend/metal/distributed.cpp
@@ -3,6 +3,7 @@
 #include <cassert>
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
@@ -89,13 +90,14 @@ void Send::eval_gpu(
 
   auto& in = inputs[0];
   auto& out = outputs[0];
+  move_or_copy(in, out);
 
   // Schedule an async send on the comm stream
   auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
     if (in.event().valid()) {
       in.event().wait();
     }
-    distributed::detail::send(group, in, dst);
+    distributed::detail::send(group, out, dst);
     out.event().signal();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
@@ -133,6 +135,7 @@ void Recv::eval_gpu(
   // Encode a wait event as there is no input for the recv to encode a signal.
   auto& s = stream();
   auto& d = metal::device(s.device);
+  d.end_encoding(s.index);
   auto command_buffer = d.get_command_buffer(s.index);
   command_buffer->encodeWait(
       static_cast<MTL::Event*>(out.event().raw_event().get()),
diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp
index 8f3148778..2af552664 100644
--- a/mlx/distributed/ops.cpp
+++ b/mlx/distributed/ops.cpp
@@ -78,7 +78,10 @@ array send(
   }
 
   return array(
-      {0}, int32, std::make_shared<Send>(to_stream(s), group, dst), {x});
+      x.shape(),
+      x.dtype(),
+      std::make_shared<Send>(to_stream(s), group, dst),
+      {x});
 }
 
 array recv(
diff --git a/mlx/distributed/primitives.cpp b/mlx/distributed/primitives.cpp
index 84ce86ffa..2a151cc0c 100644
--- a/mlx/distributed/primitives.cpp
+++ b/mlx/distributed/primitives.cpp
@@ -3,6 +3,7 @@
 #include <cassert>
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
 #include "mlx/ops.h"
@@ -105,6 +106,7 @@ void Send::eval_cpu(
   assert(outputs.size() == 1);
 
   distributed::detail::send(group(), inputs[0], dst_);
+  move_or_copy(inputs[0], outputs[0]);
 }
 
 std::pair<std::vector<array>, std::vector<int>> Send::vmap(
diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp
index ebce7acb5..f0459b8d3 100644
--- a/python/src/distributed.cpp
+++ b/python/src/distributed.cpp
@@ -148,7 +148,7 @@ void init_distributed(nb::module_& parent_module) {
           in which case the default stream of the default device is used.
 
         Returns:
-            array: An empty array which when evaluated the send is performed.
+            array: An array identical to ``x`` which when evaluated the send is performed.
       )pbdoc");
   m.def(
diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py
index aa261a6e5..2af7fcf9a 100644
--- a/python/tests/mpi_test_distributed.py
+++ b/python/tests/mpi_test_distributed.py
@@ -111,6 +111,14 @@ class TestDistributed(mlx_tests.MLXTestCase):
 
         self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
 
+        # Check recv and computation in same eval:
+        y = mx.ones((5, 5)) + mx.array(2.0)
+        if send:
+            x = mx.distributed.send(2 * x, neighbor, group=pairs)
+        else:
+            x = mx.distributed.recv_like(x, neighbor, group=pairs)
+        mx.eval(y, x)
+
     def test_average_gradients(self):
         original_all_sum = mx.distributed.all_sum
         n_calls = 0
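
Note: with this patch, ``mx.distributed.send`` returns an array with the same
shape, dtype, and data as its input (previously a shape ``{0}`` int32
placeholder), so a send can participate in a larger graph and be evaluated
together with other work in a single ``mx.eval``. Below is a minimal sketch of
that usage pattern; it is not part of the patch, it assumes a two-process MPI
launch (e.g. ``mpirun -np 2``), and the variable names (``group``, ``peer``,
``x``, ``y``) and array shapes are illustrative only.

    import mlx.core as mx

    group = mx.distributed.init()
    # Pair ranks 0 and 1: rank 0 sends, rank 1 receives.
    peer = (group.rank() + 1) % group.size()

    x = mx.ones((4, 4)) * group.rank()
    if group.rank() == 0:
        # send() now returns an array identical to its input, so the
        # result can feed later computation instead of being a dummy.
        x = mx.distributed.send(x, peer, group=group)
    else:
        x = mx.distributed.recv_like(x, peer, group=group)

    y = x + 1  # computation that depends on the communicated value
    mx.eval(y)  # one eval performs both the communication and the add

Judging from the ``move_or_copy`` calls added in both the CPU and Metal eval
paths above, the returned array aliases the input's buffer rather than copying
the data, so the new output costs no extra allocation.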