mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Raw sockets
This commit is contained in:
		| @@ -2,6 +2,7 @@ | ||||
|  | ||||
| #include <arpa/inet.h> | ||||
| #include <json.hpp> | ||||
| #include <net/ndrv.h> | ||||
| #include <netdb.h> | ||||
| #include <sys/socket.h> | ||||
| #include <unistd.h> | ||||
| @@ -73,9 +74,9 @@ | ||||
|     } break;                 \ | ||||
|   } | ||||
|  | ||||
| constexpr const size_t PACKET_SIZE = 262144; | ||||
| constexpr const int CONN_ATTEMPTS = 5; | ||||
| constexpr const int CONN_WAIT = 1000; | ||||
| constexpr const size_t PACKET_SIZE = 1408; | ||||
| constexpr const uint16_t ETHER_TYPE = 32923; | ||||
| constexpr const uint16_t ETHER_TYPE_NTOHS = ntohs(ETHER_TYPE); | ||||
|  | ||||
| using json = nlohmann::json; | ||||
|  | ||||
| @@ -107,170 +108,122 @@ array ensure_row_contiguous(const array& arr) { | ||||
|   } | ||||
| } | ||||
|  | ||||
| struct address_t { | ||||
|   sockaddr_storage addr; | ||||
|   socklen_t len; | ||||
| struct mac_address { | ||||
|   uint8_t raw[6] = {0}; | ||||
|  | ||||
|   const sockaddr* sockaddr() { | ||||
|     return (struct sockaddr*)&addr; | ||||
|   mac_address(const std::string& address) { | ||||
|     auto hex_to_int = [](const char c) -> uint8_t { | ||||
|       if (c >= 'a' && c <= 'f') { | ||||
|         return c - 'a' + 10; | ||||
|       } | ||||
|  | ||||
|       if (c >= 'A' && c <= 'F') { | ||||
|         return c - 'A' + 10; | ||||
|       } | ||||
|  | ||||
|       if (c >= '0' && c <= '9') { | ||||
|         return c - '0'; | ||||
|       } | ||||
|  | ||||
|       return 0; | ||||
|     }; | ||||
|  | ||||
|     int idx = 0; | ||||
|     int cnt = 0; | ||||
|     for (const auto c : address) { | ||||
|       if (c == ':') { | ||||
|         idx += 1; | ||||
|         cnt = 0; | ||||
|         if (idx >= 6) { | ||||
|           break; | ||||
|         } | ||||
|       } else { | ||||
|         raw[idx] <<= 4 * cnt; | ||||
|         raw[idx] += hex_to_int(c); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   void to_buffer(char* buf) { | ||||
|     for (int i = 0; i < 6; i++) { | ||||
|       buf[i] = ((char*)raw)[i]; | ||||
|     } | ||||
|   } | ||||
| }; | ||||
|  | ||||
| address_t parse_address(std::string ip, std::string port) { | ||||
|   struct addrinfo hints, *res; | ||||
|   memset(&hints, 0, sizeof(hints)); | ||||
|   hints.ai_family = AF_UNSPEC; | ||||
|   hints.ai_socktype = SOCK_STREAM; | ||||
|  | ||||
|   int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res); | ||||
|   if (status != 0) { | ||||
|     std::ostringstream msg; | ||||
|     msg << "Can't parse peer address " << ip << ":" << port; | ||||
|     throw std::runtime_error(msg.str()); | ||||
|   } | ||||
|  | ||||
|   address_t result; | ||||
|   memcpy(&result.addr, res->ai_addr, res->ai_addrlen); | ||||
|   result.len = res->ai_addrlen; | ||||
|   freeaddrinfo(res); | ||||
|  | ||||
|   return result; | ||||
| } | ||||
|  | ||||
| std::vector<address_t> load_peers() { | ||||
|   std::vector<address_t> peers; | ||||
| std::pair<std::string, std::vector<mac_address>> parse_config() { | ||||
|   std::vector<mac_address> peers; | ||||
|   std::ifstream f; | ||||
|  | ||||
|   if (const char* hostfile_buf = std::getenv("MLX_HOSTFILE")) { | ||||
|     f.open(hostfile_buf); | ||||
|   } else { | ||||
|     return peers; | ||||
|     return {"lo0", peers}; | ||||
|   } | ||||
|  | ||||
|   json hosts = json::parse(f); | ||||
|   for (auto& h : hosts) { | ||||
|     peers.push_back(std::move(parse_address( | ||||
|         h["ip"].template get<std::string>(), | ||||
|         h["port"].template get<std::string>()))); | ||||
|   json config = json::parse(f); | ||||
|   for (auto& h : config["peers"]) { | ||||
|     peers.emplace_back(h.get<std::string>()); | ||||
|   } | ||||
|  | ||||
|   return peers; | ||||
|   return {config["interface"].get<std::string>(), peers}; | ||||
| } | ||||
|  | ||||
| struct GroupImpl { | ||||
|   GroupImpl(std::vector<address_t> peers, int rank, bool global) | ||||
|       : rank_(rank), global_(global), pool_(4), sockets_(peers.size(), -1) { | ||||
|     if (rank_ > 0 && rank_ >= peers.size()) { | ||||
|   GroupImpl( | ||||
|       const std::string& interface, | ||||
|       std::vector<mac_address> peers, | ||||
|       int rank, | ||||
|       bool global) | ||||
|       : rank_(rank), global_(global), pool_(1), peers_(std::move(peers)) { | ||||
|     if (rank_ > 0 && rank_ >= peers_.size()) { | ||||
|       throw std::runtime_error( | ||||
|           "Rank cannot be larger than the size of the group"); | ||||
|     } | ||||
|  | ||||
|     int success; | ||||
|  | ||||
|     // If we are expecting anyone to connect to us | ||||
|     if (rank_ + 1 < peers.size()) { | ||||
|       // Create the socket to wait for connections from the peers | ||||
|       int sock = socket(AF_INET, SOCK_STREAM, 0); | ||||
|       if (sock < 0) { | ||||
|         std::ostringstream msg; | ||||
|         msg << "Couldn't create socket (error: " << errno << ")"; | ||||
|         throw std::runtime_error(msg.str()); | ||||
|       } | ||||
|  | ||||
|       // Make sure we can launch immediately after shutdown by setting the | ||||
|       // reuseaddr option so that we don't get address already in use errors | ||||
|       int enable = 1; | ||||
|       success = | ||||
|           setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); | ||||
|       if (success < 0) { | ||||
|         shutdown(sock, 2); | ||||
|         close(sock); | ||||
|         std::ostringstream msg; | ||||
|         msg << "Couldn't enable reuseaddr (rank: " << rank_ | ||||
|             << " error: " << errno << ")"; | ||||
|         throw std::runtime_error(msg.str()); | ||||
|       } | ||||
|       success = | ||||
|           setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int)); | ||||
|       if (success < 0) { | ||||
|         shutdown(sock, 2); | ||||
|         close(sock); | ||||
|         std::ostringstream msg; | ||||
|         msg << "Couldn't enable reuseport (rank: " << rank_ | ||||
|             << " error: " << errno << ")"; | ||||
|         throw std::runtime_error(msg.str()); | ||||
|       } | ||||
|  | ||||
|       // Bind it to the port | ||||
|       success = bind(sock, peers[rank_].sockaddr(), peers[rank_].len); | ||||
|       if (success < 0) { | ||||
|         shutdown(sock, 2); | ||||
|         close(sock); | ||||
|         std::ostringstream msg; | ||||
|         msg << "Couldn't bind socket (rank: " << rank_ << " error: " << errno | ||||
|             << ")"; | ||||
|         throw std::runtime_error(msg.str()); | ||||
|       } | ||||
|  | ||||
|       // Wait for connections | ||||
|       success = listen(sock, 0); | ||||
|       if (success < 0) { | ||||
|         shutdown(sock, 2); | ||||
|         close(sock); | ||||
|         std::ostringstream msg; | ||||
|         msg << "Couldn't listen (error: " << errno << ")"; | ||||
|         throw std::runtime_error(msg.str()); | ||||
|       } | ||||
|       for (int i = 0; i < peers.size() - rank_ - 1; i++) { | ||||
|         int peer_socket = accept(sock, nullptr, nullptr); | ||||
|         if (peer_socket < 0) { | ||||
|           shutdown(sock, 2); | ||||
|           close(sock); | ||||
|           std::ostringstream msg; | ||||
|           msg << "Accept failed (error: " << errno << ")"; | ||||
|           throw std::runtime_error(msg.str()); | ||||
|         } | ||||
|         sockets_[peers.size() - 1 - i] = peer_socket; | ||||
|       } | ||||
|  | ||||
|       // Close the listening socket | ||||
|       shutdown(sock, 2); | ||||
|       close(sock); | ||||
|     if (peers_.size() == 0) { | ||||
|       return; | ||||
|     } | ||||
|  | ||||
|     // Connect to the peers with smaller rank | ||||
|     for (int i = 0; i < rank_; i++) { | ||||
|       sockets_[i] = socket(AF_INET, SOCK_STREAM, 0); | ||||
|       if (sockets_[i] < 0) { | ||||
|         std::ostringstream msg; | ||||
|         msg << "Couldn't create socket (error: " << errno << ")"; | ||||
|         throw std::runtime_error(msg.str()); | ||||
|       } | ||||
|       for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) { | ||||
|         if (attempt > 0) { | ||||
|           int wait = (1 << (attempt - 1)) * CONN_WAIT; | ||||
|           std::this_thread::sleep_for(std::chrono::milliseconds(wait)); | ||||
|         } | ||||
|         success = connect(sockets_[i], peers[i].sockaddr(), peers[i].len); | ||||
|         if (success == 0) { | ||||
|           break; | ||||
|         } | ||||
|       } | ||||
|       if (success < 0) { | ||||
|         std::ostringstream msg; | ||||
|         msg << "Couldn't connect (rank: " << rank_ << " to: " << i | ||||
|             << " error: " << errno << ")"; | ||||
|         throw std::runtime_error(msg.str()); | ||||
|       } | ||||
|     // Make the socket | ||||
|     socket_ = socket(PF_NDRV, SOCK_RAW, 0); | ||||
|     if (socket_ < 0) { | ||||
|       std::ostringstream msg; | ||||
|       msg << "Couldn't create socket (error: " << errno << ")"; | ||||
|       throw std::runtime_error(msg.str()); | ||||
|     } | ||||
|  | ||||
|     // Make the address to bind the socket | ||||
|     std::copy(interface.begin(), interface.end(), (char*)sockaddr_.snd_name); | ||||
|     sockaddr_.snd_family = PF_NDRV; | ||||
|     sockaddr_.snd_len = sizeof(sockaddr_); | ||||
|     if (bind(socket_, (sockaddr*)&sockaddr_, sizeof(sockaddr_)) < 0) { | ||||
|       std::ostringstream msg; | ||||
|       msg << "Couldn't bind socket (error: " << errno << ")"; | ||||
|       throw std::runtime_error(msg.str()); | ||||
|     } | ||||
|  | ||||
|     // Tell the kernel to filter and select for ETHER_TYPE | ||||
|     ndrv_protocol_desc desc; | ||||
|     ndrv_demux_desc demux_desc; | ||||
|     desc.version = NDRV_PROTOCOL_DESC_VERS; | ||||
|     desc.protocol_family = ETHER_TYPE; | ||||
|     desc.demux_count = 1; | ||||
|     desc.demux_list = &demux_desc; | ||||
|     demux_desc.type = NDRV_DEMUXTYPE_ETHERTYPE; | ||||
|     demux_desc.length = sizeof(uint16_t); | ||||
|     demux_desc.data.ether_type = ETHER_TYPE_NTOHS; | ||||
|     if (setsockopt( | ||||
|             socket_, SOL_NDRVPROTO, NDRV_SETDMXSPEC, &desc, sizeof(desc)) < 0) { | ||||
|       std::ostringstream msg; | ||||
|       msg << "Couldn't set socket option (error: " << errno << ")"; | ||||
|       throw std::runtime_error(msg.str()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   ~GroupImpl() { | ||||
|     if (global_) { | ||||
|       for (int sock : sockets_) { | ||||
|         shutdown(sock, 2); | ||||
|         close(sock); | ||||
|       } | ||||
|     if (global_ && socket_ > 0) { | ||||
|       close(socket_); | ||||
|     } | ||||
|   } | ||||
|  | ||||
| @@ -279,32 +232,57 @@ struct GroupImpl { | ||||
|   } | ||||
|  | ||||
|   int size() { | ||||
|     return std::max(sockets_.size(), 1ul); | ||||
|     return std::max(peers_.size(), 1ul); | ||||
|   } | ||||
|  | ||||
|   void send_packet(const char* buf, size_t len, int dst) { | ||||
|     char packet[1500]; | ||||
|     peers_[dst].to_buffer(packet); | ||||
|     peers_[rank_].to_buffer(packet + sizeof(mac_address)); | ||||
|     memcpy(packet + 2 * sizeof(mac_address), ÐER_TYPE_NTOHS, sizeof(ETHER_TYPE_NTOHS)); | ||||
|     constexpr int header_len = 2 * sizeof(mac_address) + 2; | ||||
|     memcpy(packet + header_len, buf, len); | ||||
|     int r = sendto( | ||||
|         socket_, | ||||
|         packet, | ||||
|         len + header_len, | ||||
|         0, | ||||
|         (sockaddr*)&sockaddr_, | ||||
|         sizeof(sockaddr_)); | ||||
|     if (r < 0) { | ||||
|       std::ostringstream msg; | ||||
|       msg << "Send failed (error: " << errno << ")"; | ||||
|       throw std::runtime_error(msg.str()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   void recv_packet(char* buf, size_t len, int src) { | ||||
|     char packet[1500]; | ||||
|     constexpr int header_len = 2 * sizeof(mac_address) + 2; | ||||
|     int r = ::recv(socket_, packet, len + header_len, 0); | ||||
|     if (r < 0) { | ||||
|       std::ostringstream msg; | ||||
|       msg << "Send failed (error: " << errno << ")"; | ||||
|       throw std::runtime_error(msg.str()); | ||||
|     } | ||||
|     memcpy(buf, packet + header_len, len); | ||||
|   } | ||||
|  | ||||
|   void send(const char* buf, size_t len, int dst) { | ||||
|     while (len > 0) { | ||||
|       ssize_t r = ::send(sockets_[dst], buf, len, 0); | ||||
|       if (r <= 0) { | ||||
|         std::ostringstream msg; | ||||
|         msg << "Send of " << len << " bytes failed (errno: " << errno << ")"; | ||||
|         throw std::runtime_error(msg.str()); | ||||
|       } | ||||
|       buf += r; | ||||
|       len -= r; | ||||
|       size_t l = std::min(len, PACKET_SIZE); | ||||
|       send_packet(buf, l, dst); | ||||
|       buf += l; | ||||
|       len -= l; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   void recv(char* buf, size_t len, int src) { | ||||
|     while (len > 0) { | ||||
|       ssize_t r = ::recv(sockets_[src], buf, len, 0); | ||||
|       if (r <= 0) { | ||||
|         std::ostringstream msg; | ||||
|         msg << "Recv of " << len << " bytes failed (errno: " << errno << ")"; | ||||
|         throw std::runtime_error(msg.str()); | ||||
|       } | ||||
|       buf += r; | ||||
|       len -= r; | ||||
|       size_t l = std::min(len, PACKET_SIZE); | ||||
|       recv_packet(buf, l, src); | ||||
|       buf += l; | ||||
|       len -= l; | ||||
|     } | ||||
|   } | ||||
|  | ||||
| @@ -364,7 +342,9 @@ struct GroupImpl { | ||||
|   int rank_; | ||||
|   bool global_; | ||||
|   ThreadPool pool_; | ||||
|   std::vector<int> sockets_; | ||||
|   std::vector<mac_address> peers_; | ||||
|   sockaddr_ndrv sockaddr_; | ||||
|   int socket_; | ||||
| }; | ||||
|  | ||||
| } // namespace | ||||
| @@ -389,7 +369,7 @@ Group init(bool strict /* = false */) { | ||||
|   static std::shared_ptr<GroupImpl> global_group = nullptr; | ||||
|  | ||||
|   if (global_group == nullptr) { | ||||
|     auto peers = load_peers(); | ||||
|     auto [iface, peers] = parse_config(); | ||||
|     int rank = 0; | ||||
|     if (const char* rank_buf = std::getenv("MLX_RANK")) { | ||||
|       rank = std::atoi(rank_buf); | ||||
| @@ -399,7 +379,8 @@ Group init(bool strict /* = false */) { | ||||
|         throw std::runtime_error("Can't initialize distributed"); | ||||
|       } | ||||
|     } | ||||
|     global_group = std::make_shared<GroupImpl>(std::move(peers), rank, true); | ||||
|     global_group = | ||||
|         std::make_shared<GroupImpl>(iface, std::move(peers), rank, true); | ||||
|   } | ||||
|   return Group(global_group); | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Angelos Katharopoulos
					Angelos Katharopoulos