Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-02 16:56:46 +08:00)

Commit 058d6ce683 (parent eab93985b8)

mpi send use input as output (#1750)

* mpi send use input as output
* move earlier
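For orientation, a minimal sketch of the user-facing effect, not taken from the commit: mx.distributed.send now returns an array that mirrors its input instead of an empty array, so a send or receive can be evaluated in the same mx.eval call as unrelated computation. The two-process launch, peer rank, and shapes below are illustrative assumptions.

    # Illustrative sketch (assumes a two-process MPI launch, e.g. via mpirun).
    import mlx.core as mx

    world = mx.distributed.init()
    peer = (world.rank() + 1) % world.size()

    x = mx.ones((5, 5))
    y = mx.ones((5, 5)) + mx.array(2.0)  # unrelated computation

    if world.rank() == 0:
        x = mx.distributed.send(2 * x, peer, group=world)
    else:
        x = mx.distributed.recv_like(x, peer, group=world)

    # The communication result and the unrelated work evaluate together.
    mx.eval(y, x)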
@@ -3,6 +3,7 @@
 #include <cassert>
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
@@ -89,13 +90,14 @@ void Send::eval_gpu(
   auto& in = inputs[0];
   auto& out = outputs[0];
+  move_or_copy(in, out);
 
   // Schedule an async send on the comm stream
   auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
     if (in.event().valid()) {
       in.event().wait();
     }
-    distributed::detail::send(group, in, dst);
+    distributed::detail::send(group, out, dst);
     out.event().signal();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
@@ -133,6 +135,7 @@ void Recv::eval_gpu(
   // Encode a wait event as there is no input for the recv to encode a signal.
   auto& s = stream();
   auto& d = metal::device(s.device);
+  d.end_encoding(s.index);
   auto command_buffer = d.get_command_buffer(s.index);
   command_buffer->encodeWait(
       static_cast<MTL::Event*>(out.event().raw_event().get()),
@@ -78,7 +78,10 @@ array send(
   }
 
   return array(
-      {0}, int32, std::make_shared<Send>(to_stream(s), group, dst), {x});
+      x.shape(),
+      x.dtype(),
+      std::make_shared<Send>(to_stream(s), group, dst),
+      {x});
 }
 
 array recv(
@@ -3,6 +3,7 @@
 #include <cassert>
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
 #include "mlx/ops.h"
@@ -105,6 +106,7 @@ void Send::eval_cpu(
   assert(outputs.size() == 1);
 
   distributed::detail::send(group(), inputs[0], dst_);
+  move_or_copy(inputs[0], outputs[0]);
 }
 
 std::pair<std::vector<array>, std::vector<int>> Send::vmap(
@@ -148,7 +148,7 @@ void init_distributed(nb::module_& parent_module) {
             in which case the default stream of the default device is used.
 
         Returns:
-            array: An empty array which when evaluated the send is performed.
+            array: An array identical to ``x`` which when evaluated the send is performed.
         )pbdoc");
 
     m.def(
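A hedged usage note on the updated docstring above: since send now forwards its input, the returned array has the same shape and dtype as ``x``, and evaluating it is what performs the send. A small illustrative check follows; the launch setup and peer rank are assumptions, as in the earlier sketch.

    # Hypothetical check that the result of send mirrors the input
    # (rank 0 sends, rank 1 receives so the send has a matching recv).
    import mlx.core as mx

    world = mx.distributed.init()
    peer = (world.rank() + 1) % world.size()

    x = mx.zeros((3, 2), dtype=mx.float16)
    if world.rank() == 0:
        out = mx.distributed.send(x, peer, group=world)
        assert out.shape == x.shape and out.dtype == x.dtype
        mx.eval(out)  # evaluating the result performs the send
    else:
        mx.eval(mx.distributed.recv_like(x, peer, group=world))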
@@ -111,6 +111,14 @@ class TestDistributed(mlx_tests.MLXTestCase):
 
         self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
 
+        # Check recv and computation in same eval:
+        y = mx.ones((5, 5)) + mx.array(2.0)
+        if send:
+            x = mx.distributed.send(2 * x, neighbor, group=pairs)
+        else:
+            x = mx.distributed.recv_like(x, neighbor, group=pairs)
+        mx.eval(y, x)
+
     def test_average_gradients(self):
         original_all_sum = mx.distributed.all_sum
         n_calls = 0