mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
mpi send use input as output (#1750)
* mpi send use input as output * move earlier
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
@@ -89,13 +90,14 @@ void Send::eval_gpu(
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
move_or_copy(in, out);
|
||||
|
||||
// Schedule an async send on the comm stream
|
||||
auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
}
|
||||
distributed::detail::send(group, in, dst);
|
||||
distributed::detail::send(group, out, dst);
|
||||
out.event().signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
@@ -133,6 +135,7 @@ void Recv::eval_gpu(
|
||||
// Encode a wait event as there is no input for the recv to encode a signal.
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
d.end_encoding(s.index);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
||||
|
Reference in New Issue
Block a user