diff --git a/mlx/distributed/ibv/ibv.cpp b/mlx/distributed/ibv/ibv.cpp index 37e5d1113..1b655db78 100644 --- a/mlx/distributed/ibv/ibv.cpp +++ b/mlx/distributed/ibv/ibv.cpp @@ -69,6 +69,13 @@ struct Destination { ibv_gid global_identifier; }; +std::ostream& operator<<(std::ostream& os, const Destination& dst) { + os << dst.local_id << " " << dst.queue_pair_number + << " " << dst.packet_sequence_number << " " + << dst.global_identifier; + return os; +} + /** * A buffer that can be registered to a number of protection domains. */ @@ -370,12 +377,25 @@ class SideChannel { for (int i = 0; i < size - 1; i++) { sockets_.push_back(server.accept(IBV_TAG)); } + + std::vector ranks(size-1); + for (int i = 0; i < size - 1; i++) { + sockets_[i].recv(IBV_TAG, reinterpret_cast(&ranks[i]), sizeof(int)); + ranks[i]--; + } + for (int i = 0; i < size - 1; i++) { + while (i != ranks[i]) { + std::swap(sockets_[i], sockets_[ranks[i]]); + std::swap(ranks[i], ranks[ranks[i]]); + } + } } else { sockets_.push_back(detail::TCPSocket::connect( IBV_TAG, address, 4, 1000, [](int attempt, int wait) { std::cerr << IBV_TAG << " Connection attempt " << attempt << " waiting " << wait << " ms" << std::endl; })); + sockets_[0].send(IBV_TAG, reinterpret_cast(&rank_), sizeof(int)); } } @@ -500,7 +520,17 @@ class ConnectionManager { allocate_buffers(num_buffers, num_bytes); - // Gather the information to be exchanged + // First init all connections + for (int peer = 0; peer < size_; peer++) { + if (peer == rank_) { + continue; + } + connections_[peer].queue_pair_init(); + } + + // Gather the information to be exchanged, this also serves as a barrier so + // that all peers have initialized their connections before attempting to + // transition to RTS. std::vector info; for (auto& conn : connections_) { info.emplace_back(conn.info()); @@ -513,7 +543,6 @@ class ConnectionManager { continue; } auto peer_info = all_infos[peer][rank_]; - connections_[peer].queue_pair_init(); connections_[peer].queue_pair_rtr(peer_info); connections_[peer].queue_pair_rts(); } diff --git a/mlx/distributed/utils.cpp b/mlx/distributed/utils.cpp index 1598694c2..a91fa2488 100644 --- a/mlx/distributed/utils.cpp +++ b/mlx/distributed/utils.cpp @@ -63,6 +63,14 @@ TCPSocket::TCPSocket(TCPSocket&& s) { s.sock_ = -1; } +TCPSocket& TCPSocket::operator=(TCPSocket&& s) { + if (this != &s) { + sock_ = s.sock_; + s.sock_ = -1; + } + return *this; +} + TCPSocket::TCPSocket(int s) : sock_(s) {} TCPSocket::~TCPSocket() { diff --git a/mlx/distributed/utils.h b/mlx/distributed/utils.h index ef01cf09a..ce0220931 100644 --- a/mlx/distributed/utils.h +++ b/mlx/distributed/utils.h @@ -34,6 +34,7 @@ class TCPSocket { TCPSocket(const TCPSocket&) = delete; TCPSocket& operator=(const TCPSocket&) = delete; TCPSocket(TCPSocket&& s); + TCPSocket& operator=(TCPSocket&&); ~TCPSocket(); void listen(const char* tag, const address_t& addr);