From 10b271d96302865f52c9a8408a3f09e9e9f709e8 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 20 Feb 2025 14:32:31 -0800
Subject: [PATCH] Ring update (#1885)

---
 mlx/distributed/ring/ring.cpp         | 717 ++++++++++++++------------
 python/tests/ring_test_distributed.py |  39 ++
 2 files changed, 418 insertions(+), 338 deletions(-)

diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp
index 3f1586021..2878abf9c 100644
--- a/mlx/distributed/ring/ring.cpp
+++ b/mlx/distributed/ring/ring.cpp
@@ -1,15 +1,19 @@
 // Copyright © 2024 Apple Inc.
 
 #include
+#include
 #include
 #include
 #include
 #include
+#include
 #include
+#include
 #include
 #include
+#include
 #include
@@ -80,53 +84,17 @@
 namespace mlx::core::distributed::ring {
 
-constexpr const size_t PACKET_SIZE = 262144;
+constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
+constexpr const size_t ALL_SUM_BUFFERS = 2;
 constexpr const int CONN_ATTEMPTS = 5;
 constexpr const int CONN_WAIT = 1000;
 
 using GroupImpl = mlx::core::distributed::detail::GroupImpl;
 using json = nlohmann::json;
+using namespace std::chrono_literals;
 
 namespace {
 
-class Barrier {
- public:
-  explicit Barrier(int n_threads)
-      : n_threads_(n_threads), count_(0), flag_(false) {}
-
-  void arrive_and_wait() {
-    std::unique_lock lock(mtx_);
-
-    // Keep the flag that marks the current use of the barrier. The next use
-    // is going to have this flag flipped.
-    bool initial_flag = flag_;
-
-    // Increment the count
-    count_++;
-
-    // We are the last thread to arrive so reset the count, change the flag
-    // and notify everybody.
-    if (count_ == n_threads_) {
-      count_ = 0;
-      flag_ = !flag_;
-      cv_.notify_all();
-    }
-
-    // Wait for the rest to arrive
-    else {
-      cv_.wait(lock, [this, initial_flag]() { return initial_flag != flag_; });
-    }
-  }
-
- private:
-  std::mutex mtx_;
-  std::condition_variable cv_;
-  int n_threads_;
-
-  int count_;
-  bool flag_; // we need this for sequential use of the barrier
-};
-
 template <typename T>
 void log(std::ostream& os, T first) {
   os << first << std::endl;
 }
@@ -151,6 +119,169 @@ decltype(T() * U()) ceildiv(T a, U b) {
   return (a + b - 1) / b;
 }
 
+class SocketThread {
+ public:
+  SocketThread(int fd) : fd_(fd), stop_(false) {
+    worker_ = std::thread(&SocketThread::worker, this);
+    int flags = fcntl(fd, F_GETFL, 0);
+    fcntl(fd, F_SETFL, flags | O_NONBLOCK);
+  }
+  ~SocketThread() {
+    stop_ = true;
+    condition_.notify_all();
+    worker_.join();
+    int flags = fcntl(fd_, F_GETFL, 0);
+    fcntl(fd_, F_SETFL, flags & ~O_NONBLOCK);
+  }
+
+  template <typename T>
+  std::future<void> send(T* buffer, size_t size) {
+    return send_impl(reinterpret_cast<char*>(buffer), size * sizeof(T));
+  }
+
+  template <typename T>
+  std::future<void> recv(T* buffer, size_t size) {
+    return recv_impl(reinterpret_cast<char*>(buffer), size * sizeof(T));
+  }
+
+ private:
+  struct SocketTask {
+    SocketTask(void* b, size_t s, std::promise<void>&& p)
+        : buffer(b), size(s), promise(std::move(p)) {}
+    SocketTask(SocketTask&& t)
+        : buffer(t.buffer), size(t.size), promise(std::move(t.promise)) {}
+    void* buffer;
+    size_t size;
+    std::promise<void> promise;
+  };
+
+  std::future<void> send_impl(char* buffer, size_t size) {
+    std::promise<void> send_completed_promise;
+    auto send_completed_future = send_completed_promise.get_future();
+    if (size == 0) {
+      send_completed_promise.set_value();
+      return send_completed_future;
+    }
+
+    {
+      std::unique_lock lock(queue_mutex_);
+      sends_.emplace_back(
+          SocketTask(buffer, size, std::move(send_completed_promise)));
+    }
+    condition_.notify_one();
+    return send_completed_future;
+  }
+
+  std::future<void> recv_impl(char* buffer, size_t size) {
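+    // Mirrors send_impl: queue the task under the lock, wake the worker and
+    // hand the caller a future that completes once all bytes have arrived.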
+    std::promise<void> recv_completed_promise;
+    auto recv_completed_future = recv_completed_promise.get_future();
+    if (size == 0) {
+      recv_completed_promise.set_value();
+      return recv_completed_future;
+    }
+
+    {
+      std::unique_lock lock(queue_mutex_);
+      recvs_.emplace_back(
+          SocketTask(buffer, size, std::move(recv_completed_promise)));
+    }
+    condition_.notify_one();
+    return recv_completed_future;
+  }
+
+  bool have_tasks() {
+    return !(sends_.empty() && recvs_.empty());
+  }
+
+  void worker() {
+    bool delete_recv = false;
+    bool delete_send = false;
+    while (true) {
+      {
+        std::unique_lock lock(queue_mutex_);
+
+        if (delete_recv) {
+          recvs_.front().promise.set_value();
+          recvs_.pop_front();
+          delete_recv = false;
+        }
+        if (delete_send) {
+          sends_.front().promise.set_value();
+          sends_.pop_front();
+          delete_send = false;
+        }
+
+        if (stop_) {
+          return;
+        }
+
+        if (!have_tasks()) {
+          condition_.wait(lock, [this] { return stop_ || have_tasks(); });
+          if (stop_) {
+            return;
+          }
+        }
+      }
+
+      if (!recvs_.empty()) {
+        auto& task = recvs_.front();
+        ssize_t r = ::recv(fd_, task.buffer, task.size, 0);
+        if (r >= 0) {
+          task.buffer = static_cast<char*>(task.buffer) + r;
+          task.size -= r;
+          delete_recv = task.size == 0;
+        } else if (errno != EAGAIN) {
+          log_info(
+              true, "Receiving from socket", fd_, "failed with errno", errno);
+          return;
+        }
+      }
+      if (!sends_.empty()) {
+        auto& task = sends_.front();
+        ssize_t r = ::send(fd_, task.buffer, task.size, 0);
+        if (r >= 0) {
+          task.buffer = static_cast<char*>(task.buffer) + r;
+          task.size -= r;
+          delete_send = task.size == 0;
+        } else if (errno != EAGAIN) {
+          log_info(true, "Sending to socket", fd_, "failed with errno", errno);
+          return;
+        }
+      }
+    }
+  }
+
+  int fd_;
+  bool stop_;
+  std::thread worker_;
+  std::mutex queue_mutex_;
+  std::condition_variable condition_;
+  std::list<SocketTask> sends_;
+  std::list<SocketTask> recvs_;
+};
+
+class CommunicationThreads {
+ public:
+  void add(const std::vector<int>& sockets) {
+    for (int sock : sockets) {
+      threads_.emplace(sock, sock);
+    }
+  }
+
+  template <typename T>
+  std::future<void> send(int socket, T* buffer, size_t size) {
+    return threads_.at(socket).send(buffer, size);
+  }
+
+  template <typename T>
+  std::future<void> recv(int socket, T* buffer, size_t size) {
+    return threads_.at(socket).recv(buffer, size);
+  }
+
+ private:
+  std::unordered_map<int, SocketThread> threads_;
+};
+
 struct address_t {
   sockaddr_storage addr;
   socklen_t len;
@@ -378,140 +509,6 @@ void sum_inplace(const T* input, T* output, size_t N) {
   }
 }
 
-template <typename T>
-void _send(int sock, T* data, size_t start, size_t stop) {
-  if (stop <= start) {
-    return;
-  }
-  data += start;
-  size_t len = (stop - start) * sizeof(T);
-  const char* buffer = (const char*)data;
-  while (len > 0) {
-    ssize_t r = send(sock, buffer, len, 0);
-    if (r <= 0) {
-      std::ostringstream msg;
-      msg << "Send of " << len << " bytes failed (errno: " << errno << ")";
-      throw std::runtime_error(msg.str());
-    }
-    buffer += r;
-    len -= r;
-  }
-}
-
-template <typename T>
-void _recv(int sock, T* data, size_t start, size_t stop) {
-  if (stop <= start) {
-    return;
-  }
-  data += start;
-  size_t len = (stop - start) * sizeof(T);
-  char* buffer = (char*)data;
-  while (len > 0) {
-    ssize_t r = recv(sock, buffer, len, 0);
-    if (r <= 0) {
-      std::ostringstream msg;
-      msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
-      throw std::runtime_error(msg.str());
-    }
-    buffer += r;
-    len -= r;
-  }
-}
-
-template <typename T>
-void _recv_sum(int sock, T* data, size_t start, size_t stop) {
-  if (stop <= start) {
-    return;
-  }
-  data += start;
-  char buffer[PACKET_SIZE];
-  size_t len = (stop - start) * sizeof(T);
-  while (len > 0) {
-    ssize_t r = 0;
-    do {
-      ssize_t partial_r =
-          recv(sock, buffer + r, std::min(len, PACKET_SIZE) - r, 0);
-      if (partial_r <= 0) {
-        std::ostringstream msg;
-        msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
-        throw std::runtime_error(msg.str());
-      }
-      r += partial_r;
-    } while (r % sizeof(T));
-    sum_inplace((const T*)buffer, data, r / sizeof(T));
-    data += r / sizeof(T);
-    len -= r;
-  }
-}
-
-template <typename T>
-void ring_send(
-    Barrier& barrier,
-    int socket,
-    int rank,
-    int size,
-    T* data,
-    size_t data_size,
-    int direction = -1) {
-  // We split the data into `size_` segments of size `segment_size`
-  size_t segment_size = ceildiv(data_size, size);
-
-  // Initial segment
-  int segment = rank;
-
-  // 1st send
-  for (int i = 0; i < size - 1; i++) {
-    size_t start = segment * segment_size;
-    size_t stop = std::min((segment + 1) * segment_size, data_size);
-    _send(socket, data, start, stop);
-    barrier.arrive_and_wait();
-    segment = (segment + size + direction) % size;
-  }
-
-  // 2nd send
-  for (int i = 0; i < size - 1; i++) {
-    size_t start = segment * segment_size;
-    size_t stop = std::min((segment + 1) * segment_size, data_size);
-    _send(socket, data, start, stop);
-    barrier.arrive_and_wait();
-    segment = (segment + size + direction) % size;
-  }
-}
-
-template <typename T>
-void ring_recv_sum(
-    Barrier& barrier,
-    int socket,
-    int rank,
-    int size,
-    T* data,
-    size_t data_size,
-    int direction = -1) {
-  // We split the data into `size_` segments of size `segment_size`
-  size_t segment_size = ceildiv(data_size, size);
-
-  // Initial segment
-  int segment = (rank + size + direction) % size;
-
-  // Recv sum
-  for (int i = 0; i < size - 1; i++) {
-    size_t start = segment * segment_size;
-    size_t stop = std::min((segment + 1) * segment_size, data_size);
-    _recv_sum(socket, data, start, stop);
-    barrier.arrive_and_wait();
-    segment = (segment + size + direction) % size;
-  }
-
-  // Recv
-  for (int i = 0; i < size - 1; i++) {
-    size_t start = segment * segment_size;
-    size_t stop = std::min((segment + 1) * segment_size, data_size);
-    _recv(socket, data, start, stop);
-    barrier.arrive_and_wait();
-    segment = (segment + size + direction) % size;
-  }
-}
-
 } // namespace
 
 class RingGroup : public GroupImpl {
@@ -530,50 +527,59 @@ class RingGroup : public GroupImpl {
     // first and accept after.
     if (rank_ < connect_to) {
       log_info(verbose_, "Rank", rank_, "accepting");
-      recv_sockets_ = std::move(accept_connections(nodes[rank_]));
+      sockets_left_ = std::move(accept_connections(nodes[rank_]));
       log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
-      send_sockets_ = std::move(make_connections(nodes[connect_to], verbose));
+      sockets_right_ = std::move(make_connections(nodes[connect_to], verbose));
     } else {
       log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
-      send_sockets_ = std::move(make_connections(nodes[connect_to], verbose));
+      sockets_right_ = std::move(make_connections(nodes[connect_to], verbose));
       log_info(verbose_, "Rank", rank_, "accepting");
-      recv_sockets_ = std::move(accept_connections(nodes[rank_]));
+      sockets_left_ = std::move(accept_connections(nodes[rank_]));
     }
 
-    // Failure if we couldn't make send or recv sockets
-    if (send_sockets_.empty()) {
+    // Failure if we couldn't make right or left sockets
+    if (sockets_right_.empty()) {
       std::ostringstream msg;
-      msg << "[ring] Rank " << rank_ << " has no send sockets.";
+      msg << "[ring] Rank " << rank_ << " has no sockets to the right.";
       throw std::invalid_argument(msg.str());
     }
-    if (recv_sockets_.empty()) {
+    if (sockets_left_.empty()) {
       std::ostringstream msg;
-      msg << "[ring] Rank " << rank_ << " has no recv sockets.";
+      msg << "[ring] Rank " << rank_ << " has no sockets to the left.";
       throw std::invalid_argument(msg.str());
     }
 
     // The following could be relaxed since we can define non-homogeneous rings
     // but it makes things a bit simpler for now.
-    if (send_sockets_.size() != recv_sockets_.size()) {
+    if (sockets_right_.size() != sockets_left_.size()) {
       std::ostringstream msg;
       msg << "[ring] It is required to have as many connections to the left as "
           << "to the right but rank " << rank_ << " has "
-          << send_sockets_.size() << " connections to the right and "
-          << recv_sockets_.size() << " to the left.";
+          << sockets_right_.size() << " connections to the right and "
+          << sockets_left_.size() << " to the left.";
       throw std::invalid_argument(msg.str());
     }
 
-    // Start the necessary threads for completely parallel operation on all
-    // channels. One thread to send, one to receive per socket.
-    pool_.resize(send_sockets_.size() * 2 * 2);
+    // Start the all reduce threads. One all reduce per direction per ring.
+    pool_.resize(sockets_right_.size() + sockets_left_.size());
+
+    // Create a communication thread per socket. This also converts them to
+    // non-blocking.
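+    // A single worker per socket multiplexes the queued sends and recvs, so
+    // an all sum can keep a send and a recv in flight on the same socket
+    // without one blocking the other.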
+    comm_.add(sockets_right_);
+    comm_.add(sockets_left_);
+
+    // Allocate buffers for the all sum
+    buffers_.resize(
+        (sockets_right_.size() + sockets_left_.size()) * ALL_SUM_BUFFERS *
+        ALL_SUM_SIZE);
   }
 
   ~RingGroup() {
-    for (auto s : send_sockets_) {
+    for (auto s : sockets_right_) {
       shutdown(s, 2);
       close(s);
     }
-    for (auto s : recv_sockets_) {
+    for (auto s : sockets_left_) {
       shutdown(s, 2);
       close(s);
     }
@@ -594,14 +600,42 @@ class RingGroup : public GroupImpl {
   std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
     throw std::runtime_error("[ring] Group split not supported.");
   }
+
   void all_gather(const array& input, array& output) override {
     throw std::runtime_error("[ring] All gather not supported.");
   }
-  void send(const array& input, int dst) override {
-    throw std::runtime_error("[ring] Send not supported.");
+
+  void send(const array& input_, int dst) override {
+    // Make sure that the input is row contiguous
+    array input = ensure_row_contiguous(input_);
+
+    int right = (rank_ + 1) % size_;
+    int left = (rank_ + size_ - 1) % size_;
+    if (dst == right) {
+      send(sockets_right_, input.data<char>(), input.nbytes());
+    } else if (dst == left) {
+      send(sockets_left_, input.data<char>(), input.nbytes());
+    } else {
+      std::ostringstream msg;
+      msg << "[ring] Send only supported to direct neighbors "
+          << "but tried to send to " << dst << " from " << rank_ << std::endl;
+      throw std::runtime_error(msg.str());
+    }
   }
+
   void recv(array& out, int src) override {
-    throw std::runtime_error("[ring] Recv not supported.");
+    int right = (rank_ + 1) % size_;
+    int left = (rank_ + size_ - 1) % size_;
+    if (src == right) {
+      recv(sockets_right_, out.data<char>(), out.nbytes());
+    } else if (src == left) {
+      recv(sockets_left_, out.data<char>(), out.nbytes());
+    } else {
+      std::ostringstream msg;
+      msg << "[ring] Recv only supported from direct neighbors "
+          << "but tried to recv from " << src << " to " << rank_ << std::endl;
+      throw std::runtime_error(msg.str());
+    }
   }
 
  private:
@@ -613,7 +647,8 @@ class RingGroup : public GroupImpl {
     // If the input data cannot be split into size_ segments then copy it and
    // all reduce a local buffer prefilled with 0s.
     if (input.size() < size_) {
-      // TODO: Maybe allocate dynamically so we don't have the constraint below?
+      // TODO: Maybe allocate dynamically so we don't have the constraint
+      // below?
       if (input.itemsize() * size_ > 1024) {
         std::ostringstream msg;
         msg << "Can't perform the ring all reduce of " << output.size()
@@ -621,31 +656,16 @@ class RingGroup : public GroupImpl {
         throw std::runtime_error(msg.str());
       }
 
-      std::future<void> sent, recvd;
-      auto barrier = std::make_unique<Barrier>(2);
       char buffer[1024];
       std::memset(buffer, 0, size_ * input.itemsize());
      std::memcpy(buffer, input.data<T>(), input.nbytes());
-      sent = pool_.enqueue(
-          ring_send<T>,
-          std::reference_wrapper(*barrier),
-          send_sockets_[0],
-          rank_,
-          size_,
-          (T*)buffer,
-          size_,
-          -1);
-      recvd = pool_.enqueue(
-          ring_recv_sum<T>,
-          std::reference_wrapper(*barrier),
-          recv_sockets_[0],
-          rank_,
-          size_,
-          (T*)buffer,
-          size_,
-          -1);
-      sent.wait();
-      recvd.wait();
+      all_sum_impl(
+          reinterpret_cast<T*>(buffers_.data()),
+          reinterpret_cast<T*>(buffer),
+          size_,
+          sockets_right_[0],
+          sockets_left_[0],
+          -1);
       std::memcpy(output.data<T>(), buffer, output.nbytes());
       return;
     }
@@ -655,137 +675,155 @@ class RingGroup : public GroupImpl {
       std::memcpy(output.data<T>(), input.data<T>(), input.nbytes());
     }
 
-    // All reduce in place. We have `send_channels_.size()` bidirectional
-    // channels so let's split the message up and perform as many parallel
-    // ring-reductions as possible.
-    std::vector<std::future<void>> reductions;
-    std::vector<std::unique_ptr<Barrier>> barriers;
-    size_t packets = ceildiv(output.size(), size_ * PACKET_SIZE);
+    // Split the all reduce so that each sub-reduce has at least
+    // `min_send_size` bytes to send/recv per segment.
+    constexpr size_t min_send_size = 262144;
+    size_t n_reduces = std::max(
+        std::min(
+            sockets_right_.size() + sockets_left_.size(),
+            output.nbytes() / (size_ * min_send_size)),
+        1UL);
+    size_t step = ceildiv(output.size(), n_reduces);
+    std::vector<std::future<void>> all_sums;
 
-    // Large all reduce territory so let's use all we got
-    if (packets >= 2 * send_sockets_.size()) {
-      size_t segment = ceildiv(output.size(), 2 * send_sockets_.size());
-      for (int i = 0; i < send_sockets_.size(); i++) {
-        // 1st ring reduce
-        barriers.emplace_back(std::make_unique<Barrier>(2));
-        reductions.push_back(pool_.enqueue(
-            ring_send<T>,
-            std::reference_wrapper(*barriers.back()),
-            send_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + 2 * i * segment,
-            std::min(output.size() - 2 * i * segment, segment),
-            -1));
-        reductions.push_back(pool_.enqueue(
-            ring_recv_sum<T>,
-            std::reference_wrapper(*barriers.back()),
-            recv_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + 2 * i * segment,
-            std::min(output.size() - 2 * i * segment, segment),
-            -1));
+    for (int i = 0; i < n_reduces; i++) {
+      all_sums.emplace_back(pool_.enqueue(std::bind(
+          &RingGroup::all_sum_impl<T>,
+          this,
+          reinterpret_cast<T*>(
+              buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS),
+          output.data<T>() + i * step,
+          std::min(output.size(), (i + 1) * step) - i * step,
+          sockets_right_[i / 2],
+          sockets_left_[i / 2],
+          (i % 2) ? -1 : 1)));
+    }
+    for (auto& f : all_sums) {
+      f.wait();
+    }
+  }
 
-        // 2nd ring reduce
-        barriers.emplace_back(std::make_unique<Barrier>(2));
-        reductions.push_back(pool_.enqueue(
-            ring_send<T>,
-            std::reference_wrapper(*barriers.back()),
-            recv_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + (2 * i + 1) * segment,
-            std::min(output.size() - (2 * i + 1) * segment, segment),
-            1));
-        reductions.push_back(pool_.enqueue(
-            ring_recv_sum<T>,
-            std::reference_wrapper(*barriers.back()),
-            send_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + (2 * i + 1) * segment,
-            std::min(output.size() - (2 * i + 1) * segment, segment),
-            1));
-      }
-    }
+  template <typename T>
+  void all_sum_impl(
+      T* buffer,
+      T* data,
+      size_t data_size,
+      int socket_right,
+      int socket_left,
+      int direction) {
+    // Choose which socket we send to and recv from
+    int socket_send = (direction < 0) ? socket_right : socket_left;
+    int socket_recv = (direction < 0) ? socket_left : socket_right;
+
+    // We split the data into `size_` segments of size `segment_size` and
+    // each of these into smaller segments of BUFFER_SIZE which we'll call
+    // packets.
+    size_t segment_size = ceildiv(data_size, size_);
+    size_t BUFFER_SIZE =
+        std::max(32768UL, std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2));
+    size_t n_packets = ceildiv(segment_size, BUFFER_SIZE);
+
+    // Initial segments
+    int send_segment = rank_;
+    int recv_segment = (rank_ + direction + size_) % size_;
+
+    // Plan the whole reduce in terms of sends and recvs as indices in data.
+    // It makes the actual async send and recv a bit simpler to follow when
+    // there are fewer offset calculations around.
+    std::vector<std::pair<size_t, size_t>> send_plan;
+    std::vector<std::pair<size_t, size_t>> recv_plan;
+
+    // The same send/recv operations are planned twice: first for the
+    // reduce-scatter and then for the gather.
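+    //
+    // For example, with size_ = 4 and n_packets = 2, each pass contributes
+    // (size_ - 1) * n_packets = 6 (start, stop) pairs, so send_plan and
+    // recv_plan each end up with 12 entries.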
+    for (int k = 0; k < 2; k++) {
+      for (int i = 0; i < size_ - 1; i++) {
+        size_t send_start = send_segment * segment_size;
+        size_t send_stop =
+            std::min((send_segment + 1) * segment_size, data_size);
+        size_t recv_start = recv_segment * segment_size;
+        size_t recv_stop =
+            std::min((recv_segment + 1) * segment_size, data_size);
+
+        for (size_t j = 0; j < n_packets; j++) {
+          send_plan.emplace_back(
+              std::min(send_start + j * BUFFER_SIZE, send_stop),
+              std::min(send_start + (j + 1) * BUFFER_SIZE, send_stop));
+          recv_plan.emplace_back(
+              std::min(recv_start + j * BUFFER_SIZE, recv_stop),
+              std::min(recv_start + (j + 1) * BUFFER_SIZE, recv_stop));
+        }
+
+        send_segment = (send_segment + size_ + direction) % size_;
+        recv_segment = (recv_segment + size_ + direction) % size_;
+      }
+    }
 
-    // At least 2 reductions so we can be from small to medium
-    else if (packets > 1) {
-      size_t segment = ceildiv(output.size(), packets);
-      for (int i = 0; i < send_sockets_.size(); i++) {
-        barriers.emplace_back(std::make_unique<Barrier>(2));
-        reductions.push_back(pool_.enqueue(
-            ring_send<T>,
-            std::reference_wrapper(*barriers.back()),
-            send_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + i * segment,
-            std::min(output.size() - i * segment, segment),
-            -1));
-        reductions.push_back(pool_.enqueue(
-            ring_recv_sum<T>,
-            std::reference_wrapper(*barriers.back()),
-            recv_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + i * segment,
-            std::min(output.size() - i * segment, segment),
-            -1));
-      }
-      for (int i = 0; i < packets - send_sockets_.size(); i++) {
-        barriers.emplace_back(std::make_unique<Barrier>(2));
-        reductions.push_back(pool_.enqueue(
-            ring_send<T>,
-            std::reference_wrapper(*barriers.back()),
-            recv_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + (send_sockets_.size() + i) * segment,
-            std::min(
-                output.size() - (send_sockets_.size() + i) * segment, segment),
-            1));
-        reductions.push_back(pool_.enqueue(
-            ring_recv_sum<T>,
-            std::reference_wrapper(*barriers.back()),
-            send_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + (send_sockets_.size() + i) * segment,
-            std::min(
-                output.size() - (send_sockets_.size() + i) * segment, segment),
-            1));
-      }
-    }
+
+    // Running the plan is fairly simple: we keep a send and a recv in flight
+    // while doing the summation.
+    T* recv_buffers[ALL_SUM_BUFFERS];
+    for (int i = 0; i < ALL_SUM_BUFFERS; i++) {
+      recv_buffers[i] = buffer + i * BUFFER_SIZE;
+    }
+    std::future<void> sends[2], recvs[2];
+    int a = 0;
+    int b = (n_packets > 1) ? 1 : 0;
+    for (int i = 0, j = -b; i < send_plan.size(); j++, i++) {
+      sends[a] = comm_.send(
+          socket_send,
+          data + send_plan[i].first,
+          send_plan[i].second - send_plan[i].first);
+      if (2 * i < send_plan.size()) {
+        recvs[a] = comm_.recv(
+            socket_recv,
+            recv_buffers[i % ALL_SUM_BUFFERS],
+            recv_plan[i].second - recv_plan[i].first);
+      } else {
+        recvs[a] = comm_.recv(
+            socket_recv,
+            data + recv_plan[i].first,
+            recv_plan[i].second - recv_plan[i].first);
+      }
 
-    // Small reduction which won't really benefit much from parallelization.
-    // TODO: Verify that this is true cause PACKET_SIZE * size_ can still be a
-    // fairly large array.
-    else {
-      barriers.emplace_back(std::make_unique<Barrier>(2));
-      reductions.push_back(pool_.enqueue(
-          ring_send<T>,
-          std::reference_wrapper(*barriers.back()),
-          send_sockets_[0],
-          rank_,
-          size_,
-          output.data<T>(),
-          output.size(),
-          -1));
-      reductions.push_back(pool_.enqueue(
-          ring_recv_sum<T>,
-          std::reference_wrapper(*barriers.back()),
-          recv_sockets_[0],
-          rank_,
-          size_,
-          output.data<T>(),
-          output.size(),
-          -1));
+      if (j >= 0) {
+        sends[b].wait();
+        recvs[b].wait();
+        if (2 * j < send_plan.size()) {
+          sum_inplace(
+              recv_buffers[j % ALL_SUM_BUFFERS],
+              data + recv_plan[j].first,
+              recv_plan[j].second - recv_plan[j].first);
+        }
+      }
+
+      std::swap(a, b);
     }
+    sends[b].wait();
+    recvs[b].wait();
+  }
 
-    // Wait for the reductions to finish.
-    for (auto& f : reductions) {
+  void send(const std::vector<int>& sockets, char* data, size_t data_size) {
+    size_t segment_size = ceildiv(data_size, sockets.size());
+    std::vector<std::future<void>> sends;
+    for (int i = 0; i < sockets.size(); i++) {
+      sends.emplace_back(comm_.send(
+          sockets[i],
+          data + i * segment_size,
+          std::min(data_size, (i + 1) * segment_size) - i * segment_size));
+    }
+    for (auto& f : sends) {
+      f.wait();
+    }
+  }
+
+  void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
+    size_t segment_size = ceildiv(data_size, sockets.size());
+    std::vector<std::future<void>> recvs;
+    for (int i = 0; i < sockets.size(); i++) {
+      recvs.emplace_back(comm_.recv(
+          sockets[i],
+          data + i * segment_size,
+          std::min(data_size, (i + 1) * segment_size) - i * segment_size));
+    }
+    for (auto& f : recvs) {
       f.wait();
     }
   }
@@ -796,9 +834,12 @@ class RingGroup : public GroupImpl {
   bool verbose_;
 
   ThreadPool pool_;
+  CommunicationThreads comm_;
 
-  std::vector<int> send_sockets_;
-  std::vector<int> recv_sockets_;
+  std::vector<int> sockets_right_;
+  std::vector<int> sockets_left_;
+
+  std::vector<char> buffers_;
 };
 
 bool is_available() {
diff --git a/python/tests/ring_test_distributed.py b/python/tests/ring_test_distributed.py
index 215ecb44a..0c68914bf 100644
--- a/python/tests/ring_test_distributed.py
+++ b/python/tests/ring_test_distributed.py
@@ -56,6 +56,45 @@ class TestRingDistributed(mlx_tests.MLXTestCase):
             maxrelerror = ((y - z).abs() / z.abs()).max()
             self.assertLessEqual(maxrelerror, rtol)
 
+    def test_send_recv(self):
+        world = mx.distributed.init()
+        dtypes = [
+            mx.int8,
+            mx.uint8,
+            mx.int16,
+            mx.uint16,
+            mx.int32,
+            mx.uint32,
+            mx.float32,
+            mx.float16,
+            mx.bfloat16,
+            mx.complex64,
+        ]
+        sizes = [
+            (7,),
+            (10,),
+            (1024,),
+            (1024, 1024),
+        ]
+        key = mx.random.key(0)
+        right = (world.rank() + 1) % world.size()
+        left = (world.rank() + world.size() - 1) % world.size()
+        for dt in dtypes:
+            for sh in sizes:
+                x = (
+                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
+                ).astype(dt)
+                if world.rank() % 2 == 0:
+                    y = mx.distributed.send(x[world.rank()], right)
+                    z = mx.distributed.recv_like(y, left)
+                    mx.eval(y, z)
+                else:
+                    z = mx.distributed.recv_like(x[world.rank()], left)
+                    y = mx.distributed.send(x[world.rank()], right)
+                    mx.eval(z, y)
+                self.assertTrue(mx.all(y == x[world.rank()]))
+                self.assertTrue(mx.all(z == x[left]))
+
 
 if __name__ == "__main__":
     unittest.main()
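
Appendix, illustrative only and not part of the patch: a minimal pure-Python
simulation of the two-phase schedule that `all_sum_impl` plans above, a
reduce-scatter pass followed by a gather pass around the ring. The helper name
`ring_all_sum` and the single fixed direction are inventions of this sketch;
the real code additionally splits segments into packets, double-buffers the
receives, and runs one such ring per socket in each direction.

def ring_all_sum(chunks_per_rank):
    """Simulate a ring all-sum: `chunks_per_rank[r]` holds rank r's `size`
    segments; afterwards every rank holds the element-wise segment sums."""
    size = len(chunks_per_rank)

    # Phase 1: reduce-scatter. Every rank sends one segment to the right and
    # accumulates the segment arriving from the left. After size - 1 steps
    # rank r owns the fully reduced segment (r + 1) % size.
    for step in range(size - 1):
        in_flight = []
        for r in range(size):
            seg = (r - step) % size  # segment rank r sends this step
            in_flight.append((seg, chunks_per_rank[r][seg]))
        for r in range(size):
            seg, val = in_flight[(r - 1) % size]  # arrives from the left
            chunks_per_rank[r][seg] += val

    # Phase 2: gather. Same ring traffic, but arriving segments overwrite
    # instead of accumulate, spreading the reduced segments to every rank.
    for step in range(size - 1):
        in_flight = []
        for r in range(size):
            seg = (r + 1 - step) % size
            in_flight.append((seg, chunks_per_rank[r][seg]))
        for r in range(size):
            seg, val = in_flight[(r - 1) % size]
            chunks_per_rank[r][seg] = val

    return chunks_per_rank


if __name__ == "__main__":
    data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]  # 3 ranks, 3 segments each
    out = ring_all_sum(data)
    assert all(row == [12, 15, 18] for row in out)

The assertion checks that every rank ends up with the column sums, which is
the same invariant the C++ test above verifies over real sockets.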