mpi send use input as output (#1750)

* mpi send use input as output

* move earlier
This commit is contained in:
Awni Hannun
2025-01-06 06:08:43 -08:00
committed by GitHub
parent eab93985b8
commit 058d6ce683
5 changed files with 19 additions and 3 deletions

View File

@@ -3,6 +3,7 @@
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/device.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
@@ -89,13 +90,14 @@ void Send::eval_gpu(
auto& in = inputs[0];
auto& out = outputs[0];
move_or_copy(in, out);
// Schedule an async send on the comm stream
auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
if (in.event().valid()) {
in.event().wait();
}
distributed::detail::send(group, in, dst);
distributed::detail::send(group, out, dst);
out.event().signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
@@ -133,6 +135,7 @@ void Recv::eval_gpu(
// Encode a wait event as there is no input for the recv to encode a signal.
auto& s = stream();
auto& d = metal::device(s.device);
d.end_encoding(s.index);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->encodeWait(
static_cast<MTL::Event*>(out.event().raw_event().get()),