Add a ring all gather (#1985)
commit 69e4dd506b (parent 25814a9458)

@@ -603,8 +603,8 @@ class RingGroup : public GroupImpl {
     return size_;
   }
 
-  void all_sum(const array& input_, array& output, Stream stream) override {
-    SWITCH_TYPE(output, all_sum<T>(input_, output, stream));
+  void all_sum(const array& input, array& output, Stream stream) override {
+    SWITCH_TYPE(output, all_sum<T>(input, output, stream));
   }
 
   std::shared_ptr<GroupImpl> split(int color, int key = -1) override {

@@ -612,7 +612,39 @@ class RingGroup : public GroupImpl {
   }
 
   void all_gather(const array& input, array& output, Stream stream) override {
-    throw std::runtime_error("[ring] All gather not supported.");
+    auto& encoder = cpu::get_command_encoder(stream);
+    encoder.set_input_array(input);
+    encoder.set_output_array(output);
+    encoder.dispatch([input_ptr = input.data<char>(),
+                      nbytes = input.nbytes(),
+                      output_ptr = output.data<char>(),
+                      this]() {
+      constexpr size_t min_send_size = 262144;
+      size_t n_gathers = std::max(
+          std::min(
+              sockets_right_.size() + sockets_left_.size(),
+              nbytes / min_send_size),
+          1UL);
+      size_t bytes_per_gather = ceildiv(nbytes, n_gathers);
+      std::vector<std::future<void>> all_gathers;
+      for (int i = 0; i < n_gathers; i++) {
+        auto offset = i * bytes_per_gather;
+        all_gathers.emplace_back(pool_.enqueue(std::bind(
+            &RingGroup::all_gather_impl,
+            this,
+            input_ptr + offset,
+            output_ptr + offset,
+            nbytes,
+            offset + bytes_per_gather > nbytes ? nbytes - offset
+                                               : bytes_per_gather,
+            sockets_right_[i / 2],
+            sockets_left_[i / 2],
+            (i % 2) ? -1 : 1)));
+      }
+      for (auto& f : all_gathers) {
+        f.wait();
+      }
+    });
   }
 
   void send(const array& input, int dst, Stream stream) override {
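
As a rough illustration of the chunking in the dispatch above (not part of the commit): the buffer is split into at most sockets_right_.size() + sockets_left_.size() chunks of at least 256 KiB, chunk i is handled by socket pair i / 2, and the ring direction alternates with the parity of i. A minimal Python sketch with the hypothetical helper name plan_gathers:

def plan_gathers(nbytes, n_right, n_left, min_send_size=262144):
    # Hypothetical restatement of the chunking done in all_gather() above.
    n_gathers = max(min(n_right + n_left, nbytes // min_send_size), 1)
    bytes_per_gather = (nbytes + n_gathers - 1) // n_gathers  # ceildiv
    chunks = []
    for i in range(n_gathers):
        offset = i * bytes_per_gather
        chunks.append(
            dict(
                offset=offset,
                size=min(bytes_per_gather, nbytes - offset),
                socket_pair=i // 2,            # two chunks share a socket pair...
                direction=-1 if i % 2 else 1,  # ...and travel in opposite directions
            )
        )
    return chunks

# e.g. 1 MiB over 2 right + 2 left sockets -> four 256 KiB chunks
for chunk in plan_gathers(1 << 20, 2, 2):
    print(chunk)

With 1 MiB and two sockets per side this yields four 256 KiB chunks, two circling the ring one way and two the other.
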
@@ -642,8 +674,7 @@ class RingGroup : public GroupImpl {
     encoder.dispatch(
         [out_ptr = out.data<char>(), nbytes = out.nbytes(), src, this]() {
           // NOTE: We'll check the sockets with the opposite order of send so
-          //       that
-          //       they work even with 2 nodes where left and right is the same
+          //       that they work even with 2 nodes where left and right is the same
           //       neighbor.
           int right = (rank_ + 1) % size_;
           int left = (rank_ + size_ - 1) % size_;

@@ -827,6 +858,42 @@ class RingGroup : public GroupImpl {
       recvs[b].wait();
   }
 
+  void all_gather_impl(
+      const char* input,
+      char* output,
+      size_t input_size,
+      size_t data_size,
+      int socket_right,
+      int socket_left,
+      int direction) {
+    // Choose which socket we send to and recv from
+    int socket_send = (direction < 0) ? socket_right : socket_left;
+    int socket_recv = (direction < 0) ? socket_left : socket_right;
+
+    // Initial segments
+    int send_segment = rank_;
+    int recv_segment = (rank_ + direction + size_) % size_;
+
+    // Copy our own segment in the output
+    std::memcpy(output + rank_ * input_size, input, data_size);
+
+    // Simple send/recv all gather. Possible performance improvement by
+    // splitting to multiple chunks and allowing send/recv to run a bit ahead.
+    // See all_sum_impl for an example.
+    for (int i = 0; i < size_ - 1; i++) {
+      auto sent = comm_.send(
+          socket_send, output + send_segment * input_size, data_size);
+      auto recvd = comm_.recv(
+          socket_recv, output + recv_segment * input_size, data_size);
+
+      send_segment = (send_segment + size_ + direction) % size_;
+      recv_segment = (recv_segment + size_ + direction) % size_;
+
+      sent.wait();
+      recvd.wait();
+    }
+  }
+
   void
   send(const std::vector<int>& sockets, const char* data, size_t data_size) {
     size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
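
To see why all_gather_impl leaves every rank with all size_ segments after size_ - 1 rounds, here is a small single-process simulation of the same segment rotation (plain Python with hypothetical names, not MLX API): each rank seeds its own slot, then repeatedly forwards the segment it obtained in the previous round while receiving the next one from its neighbor in the chosen direction.

from copy import deepcopy

def ring_all_gather(size, direction=1):
    # Single-process simulation of the rotation in RingGroup::all_gather_impl.
    outputs = [[None] * size for _ in range(size)]
    for rank in range(size):
        outputs[rank][rank] = f"seg{rank}"  # copy our own segment into the output

    send_segment = list(range(size))
    recv_segment = [(rank + direction) % size for rank in range(size)]

    for _ in range(size - 1):
        prev = deepcopy(outputs)
        for rank in range(size):
            # For direction=+1 we send to the left neighbor and receive from the
            # right one (and vice versa), so the sender seen by `rank` is:
            sender = (rank + direction) % size
            outputs[rank][recv_segment[rank]] = prev[sender][send_segment[sender]]
        for rank in range(size):
            send_segment[rank] = (send_segment[rank] + direction) % size
            recv_segment[rank] = (recv_segment[rank] + direction) % size

    return outputs

for row in ring_all_gather(4):
    print(row)  # every rank ends up with ['seg0', 'seg1', 'seg2', 'seg3']

Running it for size=4 prints the full segment list for every rank, in either direction, mirroring the argument that each rank forwards what it just received until the ring has gone all the way around.
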
@@ -56,6 +56,24 @@ class TestRingDistributed(mlx_tests.MLXTestCase):
             maxrelerror = ((y - z).abs() / z.abs()).max()
             self.assertLessEqual(maxrelerror, rtol)
 
+    def test_all_gather(self):
+        world = mx.distributed.init()
+        dtypes = [
+            mx.int8,
+            mx.uint8,
+            mx.int16,
+            mx.uint16,
+            mx.int32,
+            mx.uint32,
+            mx.float32,
+            mx.complex64,
+        ]
+        for dt in dtypes:
+            x = mx.ones((2, 2, 4), dtype=dt)
+            y = mx.distributed.all_gather(x)
+            self.assertEqual(y.shape, (world.size() * 2, 2, 4))
+            self.assertTrue(mx.all(y == 1))
+
     def test_send_recv(self):
         world = mx.distributed.init()
         dtypes = [
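
For context, mx.distributed.all_gather concatenates each rank's contribution along the first axis, which is what the shape assertion in the new test checks. A minimal standalone sketch, assuming the script is launched across the ring group in the usual way (it uses only the API exercised by the test, plus Group.rank()):

import mlx.core as mx

world = mx.distributed.init()

# Each rank contributes a (2, 2, 4) block; all_gather stacks the blocks along
# axis 0, so the result has shape (world.size() * 2, 2, 4).
x = mx.ones((2, 2, 4)) * (world.rank() + 1)
y = mx.distributed.all_gather(x)
print(world.rank(), y.shape)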