Fix side channel initialization for more than 2 peers

This commit is contained in:
Angelos Katharopoulos
2025-10-14 17:48:53 -07:00
parent 4dbffb3954
commit d4c1de4a8b
3 changed files with 40 additions and 2 deletions

View File

@@ -69,6 +69,13 @@ struct Destination {
ibv_gid global_identifier; ibv_gid global_identifier;
}; };
std::ostream& operator<<(std::ostream& os, const Destination& dst) {
os << dst.local_id << " " << dst.queue_pair_number
<< " " << dst.packet_sequence_number << " "
<< dst.global_identifier;
return os;
}
/** /**
* A buffer that can be registered to a number of protection domains. * A buffer that can be registered to a number of protection domains.
*/ */
@@ -370,12 +377,25 @@ class SideChannel {
for (int i = 0; i < size - 1; i++) { for (int i = 0; i < size - 1; i++) {
sockets_.push_back(server.accept(IBV_TAG)); sockets_.push_back(server.accept(IBV_TAG));
} }
std::vector<int> ranks(size-1);
for (int i = 0; i < size - 1; i++) {
sockets_[i].recv(IBV_TAG, reinterpret_cast<char*>(&ranks[i]), sizeof(int));
ranks[i]--;
}
for (int i = 0; i < size - 1; i++) {
while (i != ranks[i]) {
std::swap(sockets_[i], sockets_[ranks[i]]);
std::swap(ranks[i], ranks[ranks[i]]);
}
}
} else { } else {
sockets_.push_back(detail::TCPSocket::connect( sockets_.push_back(detail::TCPSocket::connect(
IBV_TAG, address, 4, 1000, [](int attempt, int wait) { IBV_TAG, address, 4, 1000, [](int attempt, int wait) {
std::cerr << IBV_TAG << " Connection attempt " << attempt std::cerr << IBV_TAG << " Connection attempt " << attempt
<< " waiting " << wait << " ms" << std::endl; << " waiting " << wait << " ms" << std::endl;
})); }));
sockets_[0].send(IBV_TAG, reinterpret_cast<char*>(&rank_), sizeof(int));
} }
} }
@@ -500,7 +520,17 @@ class ConnectionManager {
allocate_buffers(num_buffers, num_bytes); allocate_buffers(num_buffers, num_bytes);
// Gather the information to be exchanged // First init all connections
for (int peer = 0; peer < size_; peer++) {
if (peer == rank_) {
continue;
}
connections_[peer].queue_pair_init();
}
// Gather the information to be exchanged, this also serves as a barrier so
// that all peers have initialized their connections before attempting to
// transition to RTS.
std::vector<Destination> info; std::vector<Destination> info;
for (auto& conn : connections_) { for (auto& conn : connections_) {
info.emplace_back(conn.info()); info.emplace_back(conn.info());
@@ -513,7 +543,6 @@ class ConnectionManager {
continue; continue;
} }
auto peer_info = all_infos[peer][rank_]; auto peer_info = all_infos[peer][rank_];
connections_[peer].queue_pair_init();
connections_[peer].queue_pair_rtr(peer_info); connections_[peer].queue_pair_rtr(peer_info);
connections_[peer].queue_pair_rts(); connections_[peer].queue_pair_rts();
} }

View File

@@ -63,6 +63,14 @@ TCPSocket::TCPSocket(TCPSocket&& s) {
s.sock_ = -1; s.sock_ = -1;
} }
TCPSocket& TCPSocket::operator=(TCPSocket&& s) {
if (this != &s) {
sock_ = s.sock_;
s.sock_ = -1;
}
return *this;
}
TCPSocket::TCPSocket(int s) : sock_(s) {} TCPSocket::TCPSocket(int s) : sock_(s) {}
TCPSocket::~TCPSocket() { TCPSocket::~TCPSocket() {

View File

@@ -34,6 +34,7 @@ class TCPSocket {
TCPSocket(const TCPSocket&) = delete; TCPSocket(const TCPSocket&) = delete;
TCPSocket& operator=(const TCPSocket&) = delete; TCPSocket& operator=(const TCPSocket&) = delete;
TCPSocket(TCPSocket&& s); TCPSocket(TCPSocket&& s);
TCPSocket& operator=(TCPSocket&&);
~TCPSocket(); ~TCPSocket();
void listen(const char* tag, const address_t& addr); void listen(const char* tag, const address_t& addr);