/**
 * A 48-bit Ethernet (MAC) address held as six raw bytes, in the order they
 * appear in the textual form (first group -> raw[0]).
 *
 * The struct has no members other than `raw`, so sizeof(mac_address) is 6;
 * the packet code relies on that when it computes Ethernet header offsets.
 */
struct mac_address {
  uint8_t raw[6] = {0};

  /**
   * Parse a colon-separated hexadecimal address such as "a1:b2:c3:d4:e5:f6".
   *
   * Parsing is deliberately lenient:
   *  - upper- and lower-case hex digits are accepted,
   *  - any character that is not a hex digit or ':' counts as a zero nibble,
   *  - groups that are missing stay zero,
   *  - everything after the sixth group is ignored.
   *
   * @param address Textual MAC address.
   */
  mac_address(const std::string& address) {
    auto hex_to_int = [](const char c) -> uint8_t {
      if (c >= 'a' && c <= 'f') {
        return static_cast<uint8_t>(c - 'a' + 10);
      }
      if (c >= 'A' && c <= 'F') {
        return static_cast<uint8_t>(c - 'A' + 10);
      }
      if (c >= '0' && c <= '9') {
        return static_cast<uint8_t>(c - '0');
      }
      return 0;
    };

    int idx = 0;
    for (const char c : address) {
      if (c == ':') {
        // Move on to the next byte; stop once all six have been filled.
        idx += 1;
        if (idx >= 6) {
          break;
        }
      } else {
        // Shift the nibbles seen so far up and append the new one. The shift
        // must happen for every digit after the first one of a group,
        // otherwise the two digits are merely added together.
        raw[idx] = static_cast<uint8_t>((raw[idx] << 4) | hex_to_int(c));
      }
    }
  }

  /**
   * Copy the six address bytes to `buf`.
   *
   * @param buf Destination with room for at least 6 bytes.
   */
  void to_buffer(char* buf) const {
    for (int i = 0; i < 6; i++) {
      buf[i] = static_cast<char>(raw[i]);
    }
  }
};
port.c_str(), &hints, &res); - if (status != 0) { - std::ostringstream msg; - msg << "Can't parse peer address " << ip << ":" << port; - throw std::runtime_error(msg.str()); - } - - address_t result; - memcpy(&result.addr, res->ai_addr, res->ai_addrlen); - result.len = res->ai_addrlen; - freeaddrinfo(res); - - return result; -} - -std::vector load_peers() { - std::vector peers; +std::pair> parse_config() { + std::vector peers; std::ifstream f; if (const char* hostfile_buf = std::getenv("MLX_HOSTFILE")) { f.open(hostfile_buf); } else { - return peers; + return {"lo0", peers}; } - json hosts = json::parse(f); - for (auto& h : hosts) { - peers.push_back(std::move(parse_address( - h["ip"].template get(), - h["port"].template get()))); + json config = json::parse(f); + for (auto& h : config["peers"]) { + peers.emplace_back(h.get()); } - return peers; + return {config["interface"].get(), peers}; } struct GroupImpl { - GroupImpl(std::vector peers, int rank, bool global) - : rank_(rank), global_(global), pool_(4), sockets_(peers.size(), -1) { - if (rank_ > 0 && rank_ >= peers.size()) { + GroupImpl( + const std::string& interface, + std::vector peers, + int rank, + bool global) + : rank_(rank), global_(global), pool_(1), peers_(std::move(peers)) { + if (rank_ > 0 && rank_ >= peers_.size()) { throw std::runtime_error( "Rank cannot be larger than the size of the group"); } - int success; - - // If we are expecting anyone to connect to us - if (rank_ + 1 < peers.size()) { - // Create the socket to wait for connections from the peers - int sock = socket(AF_INET, SOCK_STREAM, 0); - if (sock < 0) { - std::ostringstream msg; - msg << "Couldn't create socket (error: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - - // Make sure we can launch immediately after shutdown by setting the - // reuseaddr option so that we don't get address already in use errors - int enable = 1; - success = - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); - if (success < 
0) { - shutdown(sock, 2); - close(sock); - std::ostringstream msg; - msg << "Couldn't enable reuseaddr (rank: " << rank_ - << " error: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - success = - setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int)); - if (success < 0) { - shutdown(sock, 2); - close(sock); - std::ostringstream msg; - msg << "Couldn't enable reuseport (rank: " << rank_ - << " error: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - - // Bind it to the port - success = bind(sock, peers[rank_].sockaddr(), peers[rank_].len); - if (success < 0) { - shutdown(sock, 2); - close(sock); - std::ostringstream msg; - msg << "Couldn't bind socket (rank: " << rank_ << " error: " << errno - << ")"; - throw std::runtime_error(msg.str()); - } - - // Wait for connections - success = listen(sock, 0); - if (success < 0) { - shutdown(sock, 2); - close(sock); - std::ostringstream msg; - msg << "Couldn't listen (error: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - for (int i = 0; i < peers.size() - rank_ - 1; i++) { - int peer_socket = accept(sock, nullptr, nullptr); - if (peer_socket < 0) { - shutdown(sock, 2); - close(sock); - std::ostringstream msg; - msg << "Accept failed (error: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - sockets_[peers.size() - 1 - i] = peer_socket; - } - - // Close the listening socket - shutdown(sock, 2); - close(sock); + if (peers_.size() == 0) { + return; } - // Connect to the peers with smaller rank - for (int i = 0; i < rank_; i++) { - sockets_[i] = socket(AF_INET, SOCK_STREAM, 0); - if (sockets_[i] < 0) { - std::ostringstream msg; - msg << "Couldn't create socket (error: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) { - if (attempt > 0) { - int wait = (1 << (attempt - 1)) * CONN_WAIT; - std::this_thread::sleep_for(std::chrono::milliseconds(wait)); - } - success = connect(sockets_[i], 
peers[i].sockaddr(), peers[i].len); - if (success == 0) { - break; - } - } - if (success < 0) { - std::ostringstream msg; - msg << "Couldn't connect (rank: " << rank_ << " to: " << i - << " error: " << errno << ")"; - throw std::runtime_error(msg.str()); - } + // Make the socket + socket_ = socket(PF_NDRV, SOCK_RAW, 0); + if (socket_ < 0) { + std::ostringstream msg; + msg << "Couldn't create socket (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + // Make the address to bind the socket + std::copy(interface.begin(), interface.end(), (char*)sockaddr_.snd_name); + sockaddr_.snd_family = PF_NDRV; + sockaddr_.snd_len = sizeof(sockaddr_); + if (bind(socket_, (sockaddr*)&sockaddr_, sizeof(sockaddr_)) < 0) { + std::ostringstream msg; + msg << "Couldn't bind socket (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + // Tell the kernel to filter and select for ETHER_TYPE + ndrv_protocol_desc desc; + ndrv_demux_desc demux_desc; + desc.version = NDRV_PROTOCOL_DESC_VERS; + desc.protocol_family = ETHER_TYPE; + desc.demux_count = 1; + desc.demux_list = &demux_desc; + demux_desc.type = NDRV_DEMUXTYPE_ETHERTYPE; + demux_desc.length = sizeof(uint16_t); + demux_desc.data.ether_type = ETHER_TYPE_NTOHS; + if (setsockopt( + socket_, SOL_NDRVPROTO, NDRV_SETDMXSPEC, &desc, sizeof(desc)) < 0) { + std::ostringstream msg; + msg << "Couldn't set socket option (error: " << errno << ")"; + throw std::runtime_error(msg.str()); } } ~GroupImpl() { - if (global_) { - for (int sock : sockets_) { - shutdown(sock, 2); - close(sock); - } + if (global_ && socket_ > 0) { + close(socket_); } } @@ -279,32 +232,57 @@ struct GroupImpl { } int size() { - return std::max(sockets_.size(), 1ul); + return std::max(peers_.size(), 1ul); + } + + void send_packet(const char* buf, size_t len, int dst) { + char packet[1500]; + peers_[dst].to_buffer(packet); + peers_[rank_].to_buffer(packet + sizeof(mac_address)); + memcpy(packet + 2 * sizeof(mac_address), ÐER_TYPE_NTOHS, 
sizeof(ETHER_TYPE_NTOHS)); + constexpr int header_len = 2 * sizeof(mac_address) + 2; + memcpy(packet + header_len, buf, len); + int r = sendto( + socket_, + packet, + len + header_len, + 0, + (sockaddr*)&sockaddr_, + sizeof(sockaddr_)); + if (r < 0) { + std::ostringstream msg; + msg << "Send failed (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + } + + void recv_packet(char* buf, size_t len, int src) { + char packet[1500]; + constexpr int header_len = 2 * sizeof(mac_address) + 2; + int r = ::recv(socket_, packet, len + header_len, 0); + if (r < 0) { + std::ostringstream msg; + msg << "Send failed (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + memcpy(buf, packet + header_len, len); } void send(const char* buf, size_t len, int dst) { while (len > 0) { - ssize_t r = ::send(sockets_[dst], buf, len, 0); - if (r <= 0) { - std::ostringstream msg; - msg << "Send of " << len << " bytes failed (errno: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - buf += r; - len -= r; + size_t l = std::min(len, PACKET_SIZE); + send_packet(buf, l, dst); + buf += l; + len -= l; } } void recv(char* buf, size_t len, int src) { while (len > 0) { - ssize_t r = ::recv(sockets_[src], buf, len, 0); - if (r <= 0) { - std::ostringstream msg; - msg << "Recv of " << len << " bytes failed (errno: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - buf += r; - len -= r; + size_t l = std::min(len, PACKET_SIZE); + recv_packet(buf, l, src); + buf += l; + len -= l; } } @@ -364,7 +342,9 @@ struct GroupImpl { int rank_; bool global_; ThreadPool pool_; - std::vector sockets_; + std::vector peers_; + sockaddr_ndrv sockaddr_; + int socket_; }; } // namespace @@ -389,7 +369,7 @@ Group init(bool strict /* = false */) { static std::shared_ptr global_group = nullptr; if (global_group == nullptr) { - auto peers = load_peers(); + auto [iface, peers] = parse_config(); int rank = 0; if (const char* rank_buf = std::getenv("MLX_RANK")) { rank = 
std::atoi(rank_buf); @@ -399,7 +379,8 @@ Group init(bool strict /* = false */) { throw std::runtime_error("Can't initialize distributed"); } } - global_group = std::make_shared(std::move(peers), rank, true); + global_group = + std::make_shared(iface, std::move(peers), rank, true); } return Group(global_group); }