mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix side channel initialization for more than 2 peers
This commit is contained in:
@@ -69,6 +69,13 @@ struct Destination {
|
|||||||
ibv_gid global_identifier;
|
ibv_gid global_identifier;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
 * Write a Destination to a stream as four space-separated fields, in this
 * order: local_id, queue_pair_number, packet_sequence_number and
 * global_identifier. Returns the stream so insertions can be chained.
 */
std::ostream& operator<<(std::ostream& os, const Destination& dst) {
  const char* separator = " ";
  os << dst.local_id;
  os << separator << dst.queue_pair_number;
  os << separator << dst.packet_sequence_number;
  os << separator << dst.global_identifier;
  return os;
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A buffer that can be registered to a number of protection domains.
|
* A buffer that can be registered to a number of protection domains.
|
||||||
*/
|
*/
|
||||||
@@ -370,12 +377,25 @@ class SideChannel {
|
|||||||
for (int i = 0; i < size - 1; i++) {
|
for (int i = 0; i < size - 1; i++) {
|
||||||
sockets_.push_back(server.accept(IBV_TAG));
|
sockets_.push_back(server.accept(IBV_TAG));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<int> ranks(size-1);
|
||||||
|
for (int i = 0; i < size - 1; i++) {
|
||||||
|
sockets_[i].recv(IBV_TAG, reinterpret_cast<char*>(&ranks[i]), sizeof(int));
|
||||||
|
ranks[i]--;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < size - 1; i++) {
|
||||||
|
while (i != ranks[i]) {
|
||||||
|
std::swap(sockets_[i], sockets_[ranks[i]]);
|
||||||
|
std::swap(ranks[i], ranks[ranks[i]]);
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
sockets_.push_back(detail::TCPSocket::connect(
|
sockets_.push_back(detail::TCPSocket::connect(
|
||||||
IBV_TAG, address, 4, 1000, [](int attempt, int wait) {
|
IBV_TAG, address, 4, 1000, [](int attempt, int wait) {
|
||||||
std::cerr << IBV_TAG << " Connection attempt " << attempt
|
std::cerr << IBV_TAG << " Connection attempt " << attempt
|
||||||
<< " waiting " << wait << " ms" << std::endl;
|
<< " waiting " << wait << " ms" << std::endl;
|
||||||
}));
|
}));
|
||||||
|
sockets_[0].send(IBV_TAG, reinterpret_cast<char*>(&rank_), sizeof(int));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -500,7 +520,17 @@ class ConnectionManager {
|
|||||||
|
|
||||||
allocate_buffers(num_buffers, num_bytes);
|
allocate_buffers(num_buffers, num_bytes);
|
||||||
|
|
||||||
// Gather the information to be exchanged
|
// First init all connections
|
||||||
|
for (int peer = 0; peer < size_; peer++) {
|
||||||
|
if (peer == rank_) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
connections_[peer].queue_pair_init();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gather the information to be exchanged, this also serves as a barrier so
|
||||||
|
// that all peers have initialized their connections before attempting to
|
||||||
|
// transition to RTS.
|
||||||
std::vector<Destination> info;
|
std::vector<Destination> info;
|
||||||
for (auto& conn : connections_) {
|
for (auto& conn : connections_) {
|
||||||
info.emplace_back(conn.info());
|
info.emplace_back(conn.info());
|
||||||
@@ -513,7 +543,6 @@ class ConnectionManager {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto peer_info = all_infos[peer][rank_];
|
auto peer_info = all_infos[peer][rank_];
|
||||||
connections_[peer].queue_pair_init();
|
|
||||||
connections_[peer].queue_pair_rtr(peer_info);
|
connections_[peer].queue_pair_rtr(peer_info);
|
||||||
connections_[peer].queue_pair_rts();
|
connections_[peer].queue_pair_rts();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,6 +63,14 @@ TCPSocket::TCPSocket(TCPSocket&& s) {
|
|||||||
s.sock_ = -1;
|
s.sock_ = -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Move-assign: take over the descriptor held by `s`.
 *
 * The two objects exchange descriptors instead of this object simply
 * overwriting sock_. Overwriting would drop whatever descriptor this object
 * held before the assignment without anything releasing it (for example when
 * elements of a std::vector<TCPSocket> are shifted down by erase). After the
 * exchange `s` holds the previous descriptor, or -1 if this object was
 * already moved-from, so nothing is lost.
 *
 * NOTE(review): this relies on ~TCPSocket() releasing a valid sock_ and
 * ignoring the -1 sentinel; the destructor body is not visible here, confirm.
 *
 * Self-assignment is a no-op.
 *
 * @param s Socket to take the descriptor from.
 * @return Reference to this socket.
 */
TCPSocket& TCPSocket::operator=(TCPSocket&& s) {
  if (this != &s) {
    // Hand our previous descriptor to `s` rather than discarding it.
    const int previous = sock_;
    sock_ = s.sock_;
    s.sock_ = previous;
  }
  return *this;
}
|
||||||
|
|
||||||
TCPSocket::TCPSocket(int s) : sock_(s) {}
|
TCPSocket::TCPSocket(int s) : sock_(s) {}
|
||||||
|
|
||||||
TCPSocket::~TCPSocket() {
|
TCPSocket::~TCPSocket() {
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ class TCPSocket {
|
|||||||
TCPSocket(const TCPSocket&) = delete;
|
TCPSocket(const TCPSocket&) = delete;
|
||||||
TCPSocket& operator=(const TCPSocket&) = delete;
|
TCPSocket& operator=(const TCPSocket&) = delete;
|
||||||
TCPSocket(TCPSocket&& s);
|
TCPSocket(TCPSocket&& s);
|
||||||
|
TCPSocket& operator=(TCPSocket&&);
|
||||||
~TCPSocket();
|
~TCPSocket();
|
||||||
|
|
||||||
void listen(const char* tag, const address_t& addr);
|
void listen(const char* tag, const address_t& addr);
|
||||||
|
|||||||
Reference in New Issue
Block a user