mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix side channel initialization for more than 2 peers
This commit is contained in:
@@ -69,6 +69,13 @@ struct Destination {
|
||||
ibv_gid global_identifier;
|
||||
};
|
||||
|
||||
/**
 * Write a Destination to an output stream.
 *
 * The fields are emitted in the order local_id, queue_pair_number,
 * packet_sequence_number, global_identifier, with a single space between
 * consecutive fields and no trailing separator or newline.
 *
 * NOTE(review): streaming global_identifier relies on an operator<< for
 * ibv_gid that is not visible in this part of the file -- confirm it is
 * declared before this point.
 *
 * @param os  Stream that receives the text.
 * @param dst Destination whose fields are written.
 * @return    The same stream, so insertions can be chained.
 */
std::ostream& operator<<(std::ostream& os, const Destination& dst) {
  // One insertion per field keeps the output order easy to audit.
  os << dst.local_id << " ";
  os << dst.queue_pair_number << " ";
  os << dst.packet_sequence_number << " ";
  os << dst.global_identifier;
  return os;
}
|
||||
|
||||
/**
|
||||
* A buffer that can be registered to a number of protection domains.
|
||||
*/
|
||||
@@ -370,12 +377,25 @@ class SideChannel {
|
||||
for (int i = 0; i < size - 1; i++) {
|
||||
sockets_.push_back(server.accept(IBV_TAG));
|
||||
}
|
||||
|
||||
std::vector<int> ranks(size-1);
|
||||
for (int i = 0; i < size - 1; i++) {
|
||||
sockets_[i].recv(IBV_TAG, reinterpret_cast<char*>(&ranks[i]), sizeof(int));
|
||||
ranks[i]--;
|
||||
}
|
||||
for (int i = 0; i < size - 1; i++) {
|
||||
while (i != ranks[i]) {
|
||||
std::swap(sockets_[i], sockets_[ranks[i]]);
|
||||
std::swap(ranks[i], ranks[ranks[i]]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
sockets_.push_back(detail::TCPSocket::connect(
|
||||
IBV_TAG, address, 4, 1000, [](int attempt, int wait) {
|
||||
std::cerr << IBV_TAG << " Connection attempt " << attempt
|
||||
<< " waiting " << wait << " ms" << std::endl;
|
||||
}));
|
||||
sockets_[0].send(IBV_TAG, reinterpret_cast<char*>(&rank_), sizeof(int));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -500,7 +520,17 @@ class ConnectionManager {
|
||||
|
||||
allocate_buffers(num_buffers, num_bytes);
|
||||
|
||||
// Gather the information to be exchanged
|
||||
// First init all connections
|
||||
for (int peer = 0; peer < size_; peer++) {
|
||||
if (peer == rank_) {
|
||||
continue;
|
||||
}
|
||||
connections_[peer].queue_pair_init();
|
||||
}
|
||||
|
||||
// Gather the information to be exchanged, this also serves as a barrier so
|
||||
// that all peers have initialized their connections before attempting to
|
||||
// transition to RTS.
|
||||
std::vector<Destination> info;
|
||||
for (auto& conn : connections_) {
|
||||
info.emplace_back(conn.info());
|
||||
@@ -513,7 +543,6 @@ class ConnectionManager {
|
||||
continue;
|
||||
}
|
||||
auto peer_info = all_infos[peer][rank_];
|
||||
connections_[peer].queue_pair_init();
|
||||
connections_[peer].queue_pair_rtr(peer_info);
|
||||
connections_[peer].queue_pair_rts();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user