Add a ring all gather (#1985)

This commit is contained in:
Angelos Katharopoulos 2025-03-21 13:36:51 -07:00 committed by GitHub
parent 25814a9458
commit 69e4dd506b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 91 additions and 6 deletions

View File

@ -603,8 +603,8 @@ class RingGroup : public GroupImpl {
return size_; return size_;
} }
void all_sum(const array& input_, array& output, Stream stream) override { void all_sum(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(output, all_sum<T>(input_, output, stream)); SWITCH_TYPE(output, all_sum<T>(input, output, stream));
} }
std::shared_ptr<GroupImpl> split(int color, int key = -1) override { std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
@ -612,7 +612,39 @@ class RingGroup : public GroupImpl {
} }
void all_gather(const array& input, array& output, Stream stream) override { void all_gather(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[ring] All gather not supported."); auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.set_output_array(output);
encoder.dispatch([input_ptr = input.data<char>(),
nbytes = input.nbytes(),
output_ptr = output.data<char>(),
this]() {
constexpr size_t min_send_size = 262144;
size_t n_gathers = std::max(
std::min(
sockets_right_.size() + sockets_left_.size(),
nbytes / min_send_size),
1UL);
size_t bytes_per_gather = ceildiv(nbytes, n_gathers);
std::vector<std::future<void>> all_gathers;
for (int i = 0; i < n_gathers; i++) {
auto offset = i * bytes_per_gather;
all_gathers.emplace_back(pool_.enqueue(std::bind(
&RingGroup::all_gather_impl,
this,
input_ptr + offset,
output_ptr + offset,
nbytes,
offset + bytes_per_gather > nbytes ? nbytes - offset
: bytes_per_gather,
sockets_right_[i / 2],
sockets_left_[i / 2],
(i % 2) ? -1 : 1)));
}
for (auto& f : all_gathers) {
f.wait();
}
});
} }
void send(const array& input, int dst, Stream stream) override { void send(const array& input, int dst, Stream stream) override {
@ -642,8 +674,7 @@ class RingGroup : public GroupImpl {
encoder.dispatch( encoder.dispatch(
[out_ptr = out.data<char>(), nbytes = out.nbytes(), src, this]() { [out_ptr = out.data<char>(), nbytes = out.nbytes(), src, this]() {
// NOTE: We 'll check the sockets with the opposite order of send so // NOTE: We 'll check the sockets with the opposite order of send so
// that // that they work even with 2 nodes where left and right is the same
// they work even with 2 nodes where left and right is the same
// neighbor. // neighbor.
int right = (rank_ + 1) % size_; int right = (rank_ + 1) % size_;
int left = (rank_ + size_ - 1) % size_; int left = (rank_ + size_ - 1) % size_;
@ -827,6 +858,42 @@ class RingGroup : public GroupImpl {
recvs[b].wait(); recvs[b].wait();
} }
void all_gather_impl(
const char* input,
char* output,
size_t input_size,
size_t data_size,
int socket_right,
int socket_left,
int direction) {
// Choose which socket we send to and recv from
int socket_send = (direction < 0) ? socket_right : socket_left;
int socket_recv = (direction < 0) ? socket_left : socket_right;
// Initial segments
int send_segment = rank_;
int recv_segment = (rank_ + direction + size_) % size_;
// Copy our own segment in the output
std::memcpy(output + rank_ * input_size, input, data_size);
// Simple send/recv all gather. Possible performance improvement by
// splitting to multiple chunks and allowing send/recv to run a bit ahead.
// See all_sum_impl for an example.
for (int i = 0; i < size_ - 1; i++) {
auto sent = comm_.send(
socket_send, output + send_segment * input_size, data_size);
auto recvd = comm_.recv(
socket_recv, output + recv_segment * input_size, data_size);
send_segment = (send_segment + size_ + direction) % size_;
recv_segment = (recv_segment + size_ + direction) % size_;
sent.wait();
recvd.wait();
}
}
void void
send(const std::vector<int>& sockets, const char* data, size_t data_size) { send(const std::vector<int>& sockets, const char* data, size_t data_size) {
size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size())); size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));

View File

@ -56,6 +56,24 @@ class TestRingDistributed(mlx_tests.MLXTestCase):
maxrelerror = ((y - z).abs() / z.abs()).max() maxrelerror = ((y - z).abs() / z.abs()).max()
self.assertLessEqual(maxrelerror, rtol) self.assertLessEqual(maxrelerror, rtol)
def test_all_gather(self):
world = mx.distributed.init()
dtypes = [
mx.int8,
mx.uint8,
mx.int16,
mx.uint16,
mx.int32,
mx.uint32,
mx.float32,
mx.complex64,
]
for dt in dtypes:
x = mx.ones((2, 2, 4), dtype=dt)
y = mx.distributed.all_gather(x)
self.assertEqual(y.shape, (world.size() * 2, 2, 4))
self.assertTrue(mx.all(y == 1))
def test_send_recv(self): def test_send_recv(self):
world = mx.distributed.init() world = mx.distributed.init()
dtypes = [ dtypes = [