mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Raw sockets
This commit is contained in:
		@@ -2,6 +2,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
#include <arpa/inet.h>
 | 
					#include <arpa/inet.h>
 | 
				
			||||||
#include <json.hpp>
 | 
					#include <json.hpp>
 | 
				
			||||||
 | 
					#include <net/ndrv.h>
 | 
				
			||||||
#include <netdb.h>
 | 
					#include <netdb.h>
 | 
				
			||||||
#include <sys/socket.h>
 | 
					#include <sys/socket.h>
 | 
				
			||||||
#include <unistd.h>
 | 
					#include <unistd.h>
 | 
				
			||||||
@@ -73,9 +74,9 @@
 | 
				
			|||||||
    } break;                 \
 | 
					    } break;                 \
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
constexpr const size_t PACKET_SIZE = 262144;
 | 
					constexpr const size_t PACKET_SIZE = 1408;
 | 
				
			||||||
constexpr const int CONN_ATTEMPTS = 5;
 | 
					constexpr const uint16_t ETHER_TYPE = 32923;
 | 
				
			||||||
constexpr const int CONN_WAIT = 1000;
 | 
					constexpr const uint16_t ETHER_TYPE_NTOHS = ntohs(ETHER_TYPE);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
using json = nlohmann::json;
 | 
					using json = nlohmann::json;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -107,170 +108,122 @@ array ensure_row_contiguous(const array& arr) {
 | 
				
			|||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct address_t {
 | 
					struct mac_address {
 | 
				
			||||||
  sockaddr_storage addr;
 | 
					  uint8_t raw[6] = {0};
 | 
				
			||||||
  socklen_t len;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const sockaddr* sockaddr() {
 | 
					  mac_address(const std::string& address) {
 | 
				
			||||||
    return (struct sockaddr*)&addr;
 | 
					    auto hex_to_int = [](const char c) -> uint8_t {
 | 
				
			||||||
 | 
					      if (c >= 'a' && c <= 'f') {
 | 
				
			||||||
 | 
					        return c - 'a' + 10;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      if (c >= 'A' && c <= 'F') {
 | 
				
			||||||
 | 
					        return c - 'A' + 10;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      if (c >= '0' && c <= '9') {
 | 
				
			||||||
 | 
					        return c - '0';
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      return 0;
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    int idx = 0;
 | 
				
			||||||
 | 
					    int cnt = 0;
 | 
				
			||||||
 | 
					    for (const auto c : address) {
 | 
				
			||||||
 | 
					      if (c == ':') {
 | 
				
			||||||
 | 
					        idx += 1;
 | 
				
			||||||
 | 
					        cnt = 0;
 | 
				
			||||||
 | 
					        if (idx >= 6) {
 | 
				
			||||||
 | 
					          break;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
 | 
					        raw[idx] <<= 4 * cnt;
 | 
				
			||||||
 | 
					        raw[idx] += hex_to_int(c);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void to_buffer(char* buf) {
 | 
				
			||||||
 | 
					    for (int i = 0; i < 6; i++) {
 | 
				
			||||||
 | 
					      buf[i] = ((char*)raw)[i];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
address_t parse_address(std::string ip, std::string port) {
 | 
					std::pair<std::string, std::vector<mac_address>> parse_config() {
 | 
				
			||||||
  struct addrinfo hints, *res;
 | 
					  std::vector<mac_address> peers;
 | 
				
			||||||
  memset(&hints, 0, sizeof(hints));
 | 
					 | 
				
			||||||
  hints.ai_family = AF_UNSPEC;
 | 
					 | 
				
			||||||
  hints.ai_socktype = SOCK_STREAM;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
 | 
					 | 
				
			||||||
  if (status != 0) {
 | 
					 | 
				
			||||||
    std::ostringstream msg;
 | 
					 | 
				
			||||||
    msg << "Can't parse peer address " << ip << ":" << port;
 | 
					 | 
				
			||||||
    throw std::runtime_error(msg.str());
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  address_t result;
 | 
					 | 
				
			||||||
  memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
 | 
					 | 
				
			||||||
  result.len = res->ai_addrlen;
 | 
					 | 
				
			||||||
  freeaddrinfo(res);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  return result;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
std::vector<address_t> load_peers() {
 | 
					 | 
				
			||||||
  std::vector<address_t> peers;
 | 
					 | 
				
			||||||
  std::ifstream f;
 | 
					  std::ifstream f;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (const char* hostfile_buf = std::getenv("MLX_HOSTFILE")) {
 | 
					  if (const char* hostfile_buf = std::getenv("MLX_HOSTFILE")) {
 | 
				
			||||||
    f.open(hostfile_buf);
 | 
					    f.open(hostfile_buf);
 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
    return peers;
 | 
					    return {"lo0", peers};
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  json hosts = json::parse(f);
 | 
					  json config = json::parse(f);
 | 
				
			||||||
  for (auto& h : hosts) {
 | 
					  for (auto& h : config["peers"]) {
 | 
				
			||||||
    peers.push_back(std::move(parse_address(
 | 
					    peers.emplace_back(h.get<std::string>());
 | 
				
			||||||
        h["ip"].template get<std::string>(),
 | 
					 | 
				
			||||||
        h["port"].template get<std::string>())));
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return peers;
 | 
					  return {config["interface"].get<std::string>(), peers};
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct GroupImpl {
 | 
					struct GroupImpl {
 | 
				
			||||||
  GroupImpl(std::vector<address_t> peers, int rank, bool global)
 | 
					  GroupImpl(
 | 
				
			||||||
      : rank_(rank), global_(global), pool_(4), sockets_(peers.size(), -1) {
 | 
					      const std::string& interface,
 | 
				
			||||||
    if (rank_ > 0 && rank_ >= peers.size()) {
 | 
					      std::vector<mac_address> peers,
 | 
				
			||||||
 | 
					      int rank,
 | 
				
			||||||
 | 
					      bool global)
 | 
				
			||||||
 | 
					      : rank_(rank), global_(global), pool_(1), peers_(std::move(peers)) {
 | 
				
			||||||
 | 
					    if (rank_ > 0 && rank_ >= peers_.size()) {
 | 
				
			||||||
      throw std::runtime_error(
 | 
					      throw std::runtime_error(
 | 
				
			||||||
          "Rank cannot be larger than the size of the group");
 | 
					          "Rank cannot be larger than the size of the group");
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int success;
 | 
					    if (peers_.size() == 0) {
 | 
				
			||||||
 | 
					      return;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // If we are expecting anyone to connect to us
 | 
					    // Make the socket
 | 
				
			||||||
    if (rank_ + 1 < peers.size()) {
 | 
					    socket_ = socket(PF_NDRV, SOCK_RAW, 0);
 | 
				
			||||||
      // Create the socket to wait for connections from the peers
 | 
					    if (socket_ < 0) {
 | 
				
			||||||
      int sock = socket(AF_INET, SOCK_STREAM, 0);
 | 
					 | 
				
			||||||
      if (sock < 0) {
 | 
					 | 
				
			||||||
      std::ostringstream msg;
 | 
					      std::ostringstream msg;
 | 
				
			||||||
      msg << "Couldn't create socket (error: " << errno << ")";
 | 
					      msg << "Couldn't create socket (error: " << errno << ")";
 | 
				
			||||||
      throw std::runtime_error(msg.str());
 | 
					      throw std::runtime_error(msg.str());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      // Make sure we can launch immediately after shutdown by setting the
 | 
					    // Make the address to bind the socket
 | 
				
			||||||
      // reuseaddr option so that we don't get address already in use errors
 | 
					    std::copy(interface.begin(), interface.end(), (char*)sockaddr_.snd_name);
 | 
				
			||||||
      int enable = 1;
 | 
					    sockaddr_.snd_family = PF_NDRV;
 | 
				
			||||||
      success =
 | 
					    sockaddr_.snd_len = sizeof(sockaddr_);
 | 
				
			||||||
          setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
 | 
					    if (bind(socket_, (sockaddr*)&sockaddr_, sizeof(sockaddr_)) < 0) {
 | 
				
			||||||
      if (success < 0) {
 | 
					 | 
				
			||||||
        shutdown(sock, 2);
 | 
					 | 
				
			||||||
        close(sock);
 | 
					 | 
				
			||||||
      std::ostringstream msg;
 | 
					      std::ostringstream msg;
 | 
				
			||||||
        msg << "Couldn't enable reuseaddr (rank: " << rank_
 | 
					      msg << "Couldn't bind socket (error: " << errno << ")";
 | 
				
			||||||
            << " error: " << errno << ")";
 | 
					 | 
				
			||||||
        throw std::runtime_error(msg.str());
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      success =
 | 
					 | 
				
			||||||
          setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
 | 
					 | 
				
			||||||
      if (success < 0) {
 | 
					 | 
				
			||||||
        shutdown(sock, 2);
 | 
					 | 
				
			||||||
        close(sock);
 | 
					 | 
				
			||||||
        std::ostringstream msg;
 | 
					 | 
				
			||||||
        msg << "Couldn't enable reuseport (rank: " << rank_
 | 
					 | 
				
			||||||
            << " error: " << errno << ")";
 | 
					 | 
				
			||||||
      throw std::runtime_error(msg.str());
 | 
					      throw std::runtime_error(msg.str());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      // Bind it to the port
 | 
					    // Tell the kernel to filter and select for ETHER_TYPE
 | 
				
			||||||
      success = bind(sock, peers[rank_].sockaddr(), peers[rank_].len);
 | 
					    ndrv_protocol_desc desc;
 | 
				
			||||||
      if (success < 0) {
 | 
					    ndrv_demux_desc demux_desc;
 | 
				
			||||||
        shutdown(sock, 2);
 | 
					    desc.version = NDRV_PROTOCOL_DESC_VERS;
 | 
				
			||||||
        close(sock);
 | 
					    desc.protocol_family = ETHER_TYPE;
 | 
				
			||||||
 | 
					    desc.demux_count = 1;
 | 
				
			||||||
 | 
					    desc.demux_list = &demux_desc;
 | 
				
			||||||
 | 
					    demux_desc.type = NDRV_DEMUXTYPE_ETHERTYPE;
 | 
				
			||||||
 | 
					    demux_desc.length = sizeof(uint16_t);
 | 
				
			||||||
 | 
					    demux_desc.data.ether_type = ETHER_TYPE_NTOHS;
 | 
				
			||||||
 | 
					    if (setsockopt(
 | 
				
			||||||
 | 
					            socket_, SOL_NDRVPROTO, NDRV_SETDMXSPEC, &desc, sizeof(desc)) < 0) {
 | 
				
			||||||
      std::ostringstream msg;
 | 
					      std::ostringstream msg;
 | 
				
			||||||
        msg << "Couldn't bind socket (rank: " << rank_ << " error: " << errno
 | 
					      msg << "Couldn't set socket option (error: " << errno << ")";
 | 
				
			||||||
            << ")";
 | 
					 | 
				
			||||||
      throw std::runtime_error(msg.str());
 | 
					      throw std::runtime_error(msg.str());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
      // Wait for connections
 | 
					 | 
				
			||||||
      success = listen(sock, 0);
 | 
					 | 
				
			||||||
      if (success < 0) {
 | 
					 | 
				
			||||||
        shutdown(sock, 2);
 | 
					 | 
				
			||||||
        close(sock);
 | 
					 | 
				
			||||||
        std::ostringstream msg;
 | 
					 | 
				
			||||||
        msg << "Couldn't listen (error: " << errno << ")";
 | 
					 | 
				
			||||||
        throw std::runtime_error(msg.str());
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      for (int i = 0; i < peers.size() - rank_ - 1; i++) {
 | 
					 | 
				
			||||||
        int peer_socket = accept(sock, nullptr, nullptr);
 | 
					 | 
				
			||||||
        if (peer_socket < 0) {
 | 
					 | 
				
			||||||
          shutdown(sock, 2);
 | 
					 | 
				
			||||||
          close(sock);
 | 
					 | 
				
			||||||
          std::ostringstream msg;
 | 
					 | 
				
			||||||
          msg << "Accept failed (error: " << errno << ")";
 | 
					 | 
				
			||||||
          throw std::runtime_error(msg.str());
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        sockets_[peers.size() - 1 - i] = peer_socket;
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      // Close the listening socket
 | 
					 | 
				
			||||||
      shutdown(sock, 2);
 | 
					 | 
				
			||||||
      close(sock);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Connect to the peers with smaller rank
 | 
					 | 
				
			||||||
    for (int i = 0; i < rank_; i++) {
 | 
					 | 
				
			||||||
      sockets_[i] = socket(AF_INET, SOCK_STREAM, 0);
 | 
					 | 
				
			||||||
      if (sockets_[i] < 0) {
 | 
					 | 
				
			||||||
        std::ostringstream msg;
 | 
					 | 
				
			||||||
        msg << "Couldn't create socket (error: " << errno << ")";
 | 
					 | 
				
			||||||
        throw std::runtime_error(msg.str());
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
 | 
					 | 
				
			||||||
        if (attempt > 0) {
 | 
					 | 
				
			||||||
          int wait = (1 << (attempt - 1)) * CONN_WAIT;
 | 
					 | 
				
			||||||
          std::this_thread::sleep_for(std::chrono::milliseconds(wait));
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        success = connect(sockets_[i], peers[i].sockaddr(), peers[i].len);
 | 
					 | 
				
			||||||
        if (success == 0) {
 | 
					 | 
				
			||||||
          break;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      if (success < 0) {
 | 
					 | 
				
			||||||
        std::ostringstream msg;
 | 
					 | 
				
			||||||
        msg << "Couldn't connect (rank: " << rank_ << " to: " << i
 | 
					 | 
				
			||||||
            << " error: " << errno << ")";
 | 
					 | 
				
			||||||
        throw std::runtime_error(msg.str());
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  ~GroupImpl() {
 | 
					  ~GroupImpl() {
 | 
				
			||||||
    if (global_) {
 | 
					    if (global_ && socket_ > 0) {
 | 
				
			||||||
      for (int sock : sockets_) {
 | 
					      close(socket_);
 | 
				
			||||||
        shutdown(sock, 2);
 | 
					 | 
				
			||||||
        close(sock);
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -279,32 +232,57 @@ struct GroupImpl {
 | 
				
			|||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  int size() {
 | 
					  int size() {
 | 
				
			||||||
    return std::max(sockets_.size(), 1ul);
 | 
					    return std::max(peers_.size(), 1ul);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void send_packet(const char* buf, size_t len, int dst) {
 | 
				
			||||||
 | 
					    char packet[1500];
 | 
				
			||||||
 | 
					    peers_[dst].to_buffer(packet);
 | 
				
			||||||
 | 
					    peers_[rank_].to_buffer(packet + sizeof(mac_address));
 | 
				
			||||||
 | 
					    memcpy(packet + 2 * sizeof(mac_address), ÐER_TYPE_NTOHS, sizeof(ETHER_TYPE_NTOHS));
 | 
				
			||||||
 | 
					    constexpr int header_len = 2 * sizeof(mac_address) + 2;
 | 
				
			||||||
 | 
					    memcpy(packet + header_len, buf, len);
 | 
				
			||||||
 | 
					    int r = sendto(
 | 
				
			||||||
 | 
					        socket_,
 | 
				
			||||||
 | 
					        packet,
 | 
				
			||||||
 | 
					        len + header_len,
 | 
				
			||||||
 | 
					        0,
 | 
				
			||||||
 | 
					        (sockaddr*)&sockaddr_,
 | 
				
			||||||
 | 
					        sizeof(sockaddr_));
 | 
				
			||||||
 | 
					    if (r < 0) {
 | 
				
			||||||
 | 
					      std::ostringstream msg;
 | 
				
			||||||
 | 
					      msg << "Send failed (error: " << errno << ")";
 | 
				
			||||||
 | 
					      throw std::runtime_error(msg.str());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void recv_packet(char* buf, size_t len, int src) {
 | 
				
			||||||
 | 
					    char packet[1500];
 | 
				
			||||||
 | 
					    constexpr int header_len = 2 * sizeof(mac_address) + 2;
 | 
				
			||||||
 | 
					    int r = ::recv(socket_, packet, len + header_len, 0);
 | 
				
			||||||
 | 
					    if (r < 0) {
 | 
				
			||||||
 | 
					      std::ostringstream msg;
 | 
				
			||||||
 | 
					      msg << "Send failed (error: " << errno << ")";
 | 
				
			||||||
 | 
					      throw std::runtime_error(msg.str());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    memcpy(buf, packet + header_len, len);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void send(const char* buf, size_t len, int dst) {
 | 
					  void send(const char* buf, size_t len, int dst) {
 | 
				
			||||||
    while (len > 0) {
 | 
					    while (len > 0) {
 | 
				
			||||||
      ssize_t r = ::send(sockets_[dst], buf, len, 0);
 | 
					      size_t l = std::min(len, PACKET_SIZE);
 | 
				
			||||||
      if (r <= 0) {
 | 
					      send_packet(buf, l, dst);
 | 
				
			||||||
        std::ostringstream msg;
 | 
					      buf += l;
 | 
				
			||||||
        msg << "Send of " << len << " bytes failed (errno: " << errno << ")";
 | 
					      len -= l;
 | 
				
			||||||
        throw std::runtime_error(msg.str());
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      buf += r;
 | 
					 | 
				
			||||||
      len -= r;
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void recv(char* buf, size_t len, int src) {
 | 
					  void recv(char* buf, size_t len, int src) {
 | 
				
			||||||
    while (len > 0) {
 | 
					    while (len > 0) {
 | 
				
			||||||
      ssize_t r = ::recv(sockets_[src], buf, len, 0);
 | 
					      size_t l = std::min(len, PACKET_SIZE);
 | 
				
			||||||
      if (r <= 0) {
 | 
					      recv_packet(buf, l, src);
 | 
				
			||||||
        std::ostringstream msg;
 | 
					      buf += l;
 | 
				
			||||||
        msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
 | 
					      len -= l;
 | 
				
			||||||
        throw std::runtime_error(msg.str());
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      buf += r;
 | 
					 | 
				
			||||||
      len -= r;
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -364,7 +342,9 @@ struct GroupImpl {
 | 
				
			|||||||
  int rank_;
 | 
					  int rank_;
 | 
				
			||||||
  bool global_;
 | 
					  bool global_;
 | 
				
			||||||
  ThreadPool pool_;
 | 
					  ThreadPool pool_;
 | 
				
			||||||
  std::vector<int> sockets_;
 | 
					  std::vector<mac_address> peers_;
 | 
				
			||||||
 | 
					  sockaddr_ndrv sockaddr_;
 | 
				
			||||||
 | 
					  int socket_;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
} // namespace
 | 
					} // namespace
 | 
				
			||||||
@@ -389,7 +369,7 @@ Group init(bool strict /* = false */) {
 | 
				
			|||||||
  static std::shared_ptr<GroupImpl> global_group = nullptr;
 | 
					  static std::shared_ptr<GroupImpl> global_group = nullptr;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (global_group == nullptr) {
 | 
					  if (global_group == nullptr) {
 | 
				
			||||||
    auto peers = load_peers();
 | 
					    auto [iface, peers] = parse_config();
 | 
				
			||||||
    int rank = 0;
 | 
					    int rank = 0;
 | 
				
			||||||
    if (const char* rank_buf = std::getenv("MLX_RANK")) {
 | 
					    if (const char* rank_buf = std::getenv("MLX_RANK")) {
 | 
				
			||||||
      rank = std::atoi(rank_buf);
 | 
					      rank = std::atoi(rank_buf);
 | 
				
			||||||
@@ -399,7 +379,8 @@ Group init(bool strict /* = false */) {
 | 
				
			|||||||
        throw std::runtime_error("Can't initialize distributed");
 | 
					        throw std::runtime_error("Can't initialize distributed");
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    global_group = std::make_shared<GroupImpl>(std::move(peers), rank, true);
 | 
					    global_group =
 | 
				
			||||||
 | 
					        std::make_shared<GroupImpl>(iface, std::move(peers), rank, true);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  return Group(global_group);
 | 
					  return Group(global_group);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user