diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp
index d8fe53051..a3f7e029d 100644
--- a/mlx/distributed/ring/ring.cpp
+++ b/mlx/distributed/ring/ring.cpp
@@ -603,8 +603,8 @@ class RingGroup : public GroupImpl {
     return size_;
   }
 
-  void all_sum(const array& input_, array& output, Stream stream) override {
-    SWITCH_TYPE(output, all_sum<T>(input_, output, stream));
+  void all_sum(const array& input, array& output, Stream stream) override {
+    SWITCH_TYPE(output, all_sum<T>(input, output, stream));
   }
 
   std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
@@ -612,7 +612,39 @@ class RingGroup : public GroupImpl {
   }
 
   void all_gather(const array& input, array& output, Stream stream) override {
-    throw std::runtime_error("[ring] All gather not supported.");
+    auto& encoder = cpu::get_command_encoder(stream);
+    encoder.set_input_array(input);
+    encoder.set_output_array(output);
+    encoder.dispatch([input_ptr = input.data<char>(),
+                      nbytes = input.nbytes(),
+                      output_ptr = output.data<char>(),
+                      this]() {
+      constexpr size_t min_send_size = 262144;
+      size_t n_gathers = std::max(
+          std::min(
+              sockets_right_.size() + sockets_left_.size(),
+              nbytes / min_send_size),
+          1UL);
+      size_t bytes_per_gather = ceildiv(nbytes, n_gathers);
+      std::vector<std::future<void>> all_gathers;
+      for (int i = 0; i < n_gathers; i++) {
+        auto offset = i * bytes_per_gather;
+        all_gathers.emplace_back(pool_.enqueue(std::bind(
+            &RingGroup::all_gather_impl,
+            this,
+            input_ptr + offset,
+            output_ptr + offset,
+            nbytes,
+            offset + bytes_per_gather > nbytes ? nbytes - offset
+                                               : bytes_per_gather,
+            sockets_right_[i / 2],
+            sockets_left_[i / 2],
+            (i % 2) ? -1 : 1)));
+      }
+      for (auto& f : all_gathers) {
+        f.wait();
+      }
+    });
   }
 
   void send(const array& input, int dst, Stream stream) override {
@@ -642,9 +674,8 @@ class RingGroup : public GroupImpl {
     encoder.dispatch(
         [out_ptr = out.data<char>(), nbytes = out.nbytes(), src, this]() {
           // NOTE: We 'll check the sockets with the opposite order of send so
-          // that
-          // they work even with 2 nodes where left and right is the same
-          // neighbor.
+          // that they work even with 2 nodes where left and right is the same
+          // neighbor.
           int right = (rank_ + 1) % size_;
           int left = (rank_ + size_ - 1) % size_;
           if (src == left) {
@@ -827,6 +858,42 @@ class RingGroup : public GroupImpl {
       recvs[b].wait();
     }
 
+  void all_gather_impl(
+      const char* input,
+      char* output,
+      size_t input_size,
+      size_t data_size,
+      int socket_right,
+      int socket_left,
+      int direction) {
+    // Choose which socket we send to and recv from
+    int socket_send = (direction < 0) ? socket_right : socket_left;
+    int socket_recv = (direction < 0) ? socket_left : socket_right;
+
+    // Initial segments
+    int send_segment = rank_;
+    int recv_segment = (rank_ + direction + size_) % size_;
+
+    // Copy our own segment in the output
+    std::memcpy(output + rank_ * input_size, input, data_size);
+
+    // Simple send/recv all gather. Possible performance improvement by
+    // splitting to multiple chunks and allowing send/recv to run a bit ahead.
+    // See all_sum_impl for an example.
+    for (int i = 0; i < size_ - 1; i++) {
+      auto sent = comm_.send(
+          socket_send, output + send_segment * input_size, data_size);
+      auto recvd = comm_.recv(
+          socket_recv, output + recv_segment * input_size, data_size);
+
+      send_segment = (send_segment + size_ + direction) % size_;
+      recv_segment = (recv_segment + size_ + direction) % size_;
+
+      sent.wait();
+      recvd.wait();
+    }
+  }
+
   void send(
       const std::vector<int>& sockets, const char* data, size_t data_size) {
     size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
diff --git a/python/tests/ring_test_distributed.py b/python/tests/ring_test_distributed.py
index 0c68914bf..93039095e 100644
--- a/python/tests/ring_test_distributed.py
+++ b/python/tests/ring_test_distributed.py
@@ -56,6 +56,24 @@ class TestRingDistributed(mlx_tests.MLXTestCase):
         maxrelerror = ((y - z).abs() / z.abs()).max()
         self.assertLessEqual(maxrelerror, rtol)
 
+    def test_all_gather(self):
+        world = mx.distributed.init()
+        dtypes = [
+            mx.int8,
+            mx.uint8,
+            mx.int16,
+            mx.uint16,
+            mx.int32,
+            mx.uint32,
+            mx.float32,
+            mx.complex64,
+        ]
+        for dt in dtypes:
+            x = mx.ones((2, 2, 4), dtype=dt)
+            y = mx.distributed.all_gather(x)
+            self.assertEqual(y.shape, (world.size() * 2, 2, 4))
+            self.assertTrue(mx.all(y == 1))
+
     def test_send_recv(self):
         world = mx.distributed.init()
         dtypes = [
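
Note on the chunking in all_gather: the dispatch lambda splits the message across
all left/right socket pairs, capping the number of chunks so none drops below
min_send_size (256 KiB), and hands chunk i to socket pair i / 2 with the ring
direction alternating by parity. The standalone sketch below reproduces that
arithmetic in isolation; the topology and message-size values and the
main/printf scaffolding are assumptions for illustration, not part of the patch.

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    // Same helper as in ring.cpp, reimplemented here to stay self-contained.
    static size_t ceildiv(size_t a, size_t b) {
      return (a + b - 1) / b;
    }

    int main() {
      const size_t n_sockets = 4;          // assumed: 2 right + 2 left connections
      const size_t nbytes = 1 << 20;       // assumed: 1 MiB message
      const size_t min_send_size = 262144; // 256 KiB floor, as in all_gather

      // As many parallel gathers as sockets, but never chunks below the
      // floor, and always at least one.
      size_t n_gathers =
          std::max(std::min(n_sockets, nbytes / min_send_size), size_t(1));
      size_t bytes_per_gather = ceildiv(nbytes, n_gathers);

      for (size_t i = 0; i < n_gathers; i++) {
        size_t offset = i * bytes_per_gather;
        // The last chunk may be short when nbytes does not divide evenly.
        size_t chunk = offset + bytes_per_gather > nbytes ? nbytes - offset
                                                          : bytes_per_gather;
        // Chunks share socket pairs two at a time (i / 2) and alternate ring
        // direction by parity, matching the arguments bound in all_gather.
        std::printf(
            "chunk %zu: offset=%zu size=%zu pair=%zu dir=%+d\n",
            i, offset, chunk, i / 2, (i % 2) ? -1 : 1);
      }
      return 0;
    }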
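And a minimal single-process model of the segment rotation that all_gather_impl
performs: each rank seeds its own segment, then for size_ - 1 steps forwards
the segment it most recently obtained while receiving the next one from its
neighbor, so every rank ends up holding all segments. The outputs/next/mod
names and the synchronous double-buffered loop are illustrative stand-ins for
the overlapped comm_.send/comm_.recv on sockets.

    #include <cassert>
    #include <vector>

    int main() {
      const int size = 4;      // ring size (world size)
      const int direction = 1; // +1: receive from the right neighbor

      auto mod = [&](int x) { return ((x % size) + size) % size; };

      // outputs[r] is rank r's gather buffer; -1 marks a missing segment.
      std::vector<std::vector<int>> outputs(size, std::vector<int>(size, -1));
      for (int r = 0; r < size; r++) {
        outputs[r][r] = r; // each rank copies its own segment in first
      }

      for (int i = 0; i < size - 1; i++) {
        // One synchronous ring step: every rank forwards the segment it
        // obtained most recently and receives one from its neighbor.
        auto next = outputs;
        for (int r = 0; r < size; r++) {
          int neighbor = mod(r + direction);
          next[r][mod(r + direction * (i + 1))] =
              outputs[neighbor][mod(neighbor + direction * i)];
        }
        outputs = next;
      }

      // After size - 1 steps every rank holds every segment.
      for (int r = 0; r < size; r++) {
        for (int s = 0; s < size; s++) {
          assert(outputs[r][s] == s);
        }
      }
      return 0;
    }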