mpi send use input as output (#1750)

* mpi send use input as output

* move earlier
This commit is contained in:
Awni Hannun 2025-01-06 06:08:43 -08:00 committed by GitHub
parent eab93985b8
commit 058d6ce683
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 19 additions and 3 deletions

View File

@ -3,6 +3,7 @@
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/device.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
@ -89,13 +90,14 @@ void Send::eval_gpu(
auto& in = inputs[0];
auto& out = outputs[0];
move_or_copy(in, out);
// Schedule an async send on the comm stream
auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
if (in.event().valid()) {
in.event().wait();
}
distributed::detail::send(group, in, dst);
distributed::detail::send(group, out, dst);
out.event().signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
@ -133,6 +135,7 @@ void Recv::eval_gpu(
// Encode a wait event as there is no input for the recv to encode a signal.
auto& s = stream();
auto& d = metal::device(s.device);
d.end_encoding(s.index);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->encodeWait(
static_cast<MTL::Event*>(out.event().raw_event().get()),

View File

@ -78,7 +78,10 @@ array send(
}
return array(
{0}, int32, std::make_shared<Send>(to_stream(s), group, dst), {x});
x.shape(),
x.dtype(),
std::make_shared<Send>(to_stream(s), group, dst),
{x});
}
array recv(

View File

@ -3,6 +3,7 @@
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/ops.h"
@ -105,6 +106,7 @@ void Send::eval_cpu(
assert(outputs.size() == 1);
distributed::detail::send(group(), inputs[0], dst_);
move_or_copy(inputs[0], outputs[0]);
}
std::pair<std::vector<array>, std::vector<int>> Send::vmap(

View File

@ -148,7 +148,7 @@ void init_distributed(nb::module_& parent_module) {
in which case the default stream of the default device is used.
Returns:
array: An empty array which when evaluated the send is performed.
array: An array identical to ``x``; when it is evaluated, the send is performed.
)pbdoc");
m.def(

View File

@ -111,6 +111,14 @@ class TestDistributed(mlx_tests.MLXTestCase):
self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
# Check recv and computation in same eval:
y = mx.ones((5, 5)) + mx.array(2.0)
if send:
x = mx.distributed.send(2 * x, neighbor, group=pairs)
else:
x = mx.distributed.recv_like(x, neighbor, group=pairs)
mx.eval(y, x)
def test_average_gradients(self):
original_all_sum = mx.distributed.all_sum
n_calls = 0