Fix side channel initialization for more than 2 peers

This commit is contained in:
Angelos Katharopoulos
2025-10-14 17:48:53 -07:00
parent 4dbffb3954
commit d4c1de4a8b
3 changed files with 40 additions and 2 deletions

View File

@@ -69,6 +69,13 @@ struct Destination {
ibv_gid global_identifier;
};
/**
 * Stream a Destination as four space-separated fields, in this order:
 * local_id, queue_pair_number, packet_sequence_number, global_identifier.
 *
 * No trailing separator or newline is written. The stream is returned so
 * that insertions can be chained by the caller.
 *
 * NOTE(review): formatting of global_identifier relies on a stream insertion
 * operator for its type being declared elsewhere -- confirm it is in scope.
 */
std::ostream& operator<<(std::ostream& os, const Destination& dst) {
  os << dst.local_id << " ";
  os << dst.queue_pair_number << " ";
  os << dst.packet_sequence_number << " ";
  return os << dst.global_identifier;
}
/**
* A buffer that can be registered to a number of protection domains.
*/
@@ -370,12 +377,25 @@ class SideChannel {
for (int i = 0; i < size - 1; i++) {
sockets_.push_back(server.accept(IBV_TAG));
}
std::vector<int> ranks(size-1);
for (int i = 0; i < size - 1; i++) {
sockets_[i].recv(IBV_TAG, reinterpret_cast<char*>(&ranks[i]), sizeof(int));
ranks[i]--;
}
for (int i = 0; i < size - 1; i++) {
while (i != ranks[i]) {
std::swap(sockets_[i], sockets_[ranks[i]]);
std::swap(ranks[i], ranks[ranks[i]]);
}
}
} else {
sockets_.push_back(detail::TCPSocket::connect(
IBV_TAG, address, 4, 1000, [](int attempt, int wait) {
std::cerr << IBV_TAG << " Connection attempt " << attempt
<< " waiting " << wait << " ms" << std::endl;
}));
sockets_[0].send(IBV_TAG, reinterpret_cast<char*>(&rank_), sizeof(int));
}
}
@@ -500,7 +520,17 @@ class ConnectionManager {
allocate_buffers(num_buffers, num_bytes);
// Gather the information to be exchanged
// First init all connections
for (int peer = 0; peer < size_; peer++) {
if (peer == rank_) {
continue;
}
connections_[peer].queue_pair_init();
}
// Gather the information to be exchanged. This also serves as a barrier so
// that all peers have initialized their connections before attempting to
// transition to RTS.
std::vector<Destination> info;
for (auto& conn : connections_) {
info.emplace_back(conn.info());
@@ -513,7 +543,6 @@ class ConnectionManager {
continue;
}
auto peer_info = all_infos[peer][rank_];
connections_[peer].queue_pair_init();
connections_[peer].queue_pair_rtr(peer_info);
connections_[peer].queue_pair_rts();
}