Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-29 13:55:29 +08:00)
mpi send use input as output (#1750)
* mpi send use input as output
* move earlier
This commit is contained in:
parent eab93985b8
commit 058d6ce683
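The user-visible effect: mx.distributed.send previously returned an empty placeholder array whose only purpose was to be evaluated; it now returns an array identical to its input ``x``, so the result of a send can flow into later computation. A minimal sketch of the new behavior (hypothetical usage, not part of the diff; assumes a two-rank MPI launch, e.g. via mpirun -np 2):

    import mlx.core as mx

    group = mx.distributed.init()
    peer = (group.rank() + 1) % group.size()  # partner rank, assuming exactly two ranks
    x = mx.full((4, 4), float(group.rank()))

    if group.rank() == 0:
        out = mx.distributed.send(x, peer, group=group)
    else:
        out = mx.distributed.recv_like(x, peer, group=group)

    mx.eval(out)  # evaluating out performs the send/recv
    print(out.shape, out.dtype)  # on the sender, out mirrors x instead of being empty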
Metal backend (Send::eval_gpu and Recv::eval_gpu):

@@ -3,6 +3,7 @@
 #include <cassert>
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
@@ -89,13 +90,14 @@ void Send::eval_gpu(
 
   auto& in = inputs[0];
   auto& out = outputs[0];
+  move_or_copy(in, out);
 
   // Schedule an async send on the comm stream
   auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
     if (in.event().valid()) {
       in.event().wait();
     }
-    distributed::detail::send(group, in, dst);
+    distributed::detail::send(group, out, dst);
     out.event().signal();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
@@ -133,6 +135,7 @@ void Recv::eval_gpu(
   // Encode a wait event as there is no input for the recv to encode a signal.
   auto& s = stream();
   auto& d = metal::device(s.device);
+  d.end_encoding(s.index);
   auto command_buffer = d.get_command_buffer(s.index);
   command_buffer->encodeWait(
       static_cast<MTL::Event*>(out.event().raw_event().get()),
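Note the ordering in Send::eval_gpu: move_or_copy binds the output to the input's data before the communication task is scheduled, and the task sends and then signals the output. Downstream GPU work that consumes the send's result waits on that event, so communication and dependent computation can share one eval. A hedged two-rank sketch of what this enables (illustrative names, not from the diff):

    import mlx.core as mx

    group = mx.distributed.init()
    peer = (group.rank() + 1) % group.size()  # assumes two ranks

    if group.rank() == 0:
        out = mx.distributed.send(mx.arange(16.0), peer, group=group)
    else:
        out = mx.distributed.recv_like(mx.arange(16.0), peer, group=group)

    z = out * 3 + 1  # depends on the communication's output
    mx.eval(z)       # single eval; the event wait/signal orders the GPU work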
Distributed ops (array send):

@@ -78,7 +78,10 @@ array send(
   }
 
   return array(
-      {0}, int32, std::make_shared<Send>(to_stream(s), group, dst), {x});
+      x.shape(),
+      x.dtype(),
+      std::make_shared<Send>(to_stream(s), group, dst),
+      {x});
 }
 
 array recv(
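Concretely, the op-level contract changes from returning a shape-{0} int32 placeholder to returning an array with the input's shape and dtype. A tiny hedged check under the same two-rank assumption as above:

    import mlx.core as mx

    group = mx.distributed.init()
    peer = (group.rank() + 1) % group.size()
    x = mx.zeros((3, 2), dtype=mx.float16)

    if group.rank() == 0:
        out = mx.distributed.send(x, peer, group=group)
        assert out.shape == x.shape and out.dtype == x.dtype  # new contract
        mx.eval(out)  # the send itself still happens at evaluation time
    else:
        mx.eval(mx.distributed.recv_like(x, peer, group=group))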
CPU backend (Send::eval_cpu):

@@ -3,6 +3,7 @@
 #include <cassert>
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
 #include "mlx/ops.h"
@@ -105,6 +106,7 @@ void Send::eval_cpu(
   assert(outputs.size() == 1);
 
   distributed::detail::send(group(), inputs[0], dst_);
+  move_or_copy(inputs[0], outputs[0]);
 }
 
 std::pair<std::vector<array>, std::vector<int>> Send::vmap(
Python bindings (send docstring in init_distributed):

@@ -148,7 +148,7 @@ void init_distributed(nb::module_& parent_module) {
         in which case the default stream of the default device is used.
 
       Returns:
-        array: An empty array which when evaluated the send is performed.
+        array: An array identical to ``x`` which when evaluated the send is performed.
   )pbdoc");
 
   m.def(
Python tests (TestDistributed):

@@ -111,6 +111,14 @@ class TestDistributed(mlx_tests.MLXTestCase):
 
         self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
 
+        # Check recv and computation in same eval:
+        y = mx.ones((5, 5)) + mx.array(2.0)
+        if send:
+            x = mx.distributed.send(2 * x, neighbor, group=pairs)
+        else:
+            x = mx.distributed.recv_like(x, neighbor, group=pairs)
+        mx.eval(y, x)
+
     def test_average_gradients(self):
         original_all_sum = mx.distributed.all_sum
         n_calls = 0
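For reference, the added check as a standalone script (a hedged sketch: ``pairs`` and ``neighbor`` come from the test fixture, so a fresh two-rank init stands in for them here; launch with an MPI runner such as mpirun -np 2):

    import mlx.core as mx

    pairs = mx.distributed.init()
    neighbor = pairs.rank() ^ 1  # partner rank in a two-rank launch
    x = mx.full((5, 5), float(pairs.rank()))

    y = mx.ones((5, 5)) + mx.array(2.0)  # unrelated computation
    if pairs.rank() == 0:
        x = mx.distributed.send(2 * x, neighbor, group=pairs)
    else:
        x = mx.distributed.recv_like(x, neighbor, group=pairs)
    mx.eval(y, x)  # communication and computation evaluated together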