mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Raw sockets
This commit is contained in:
parent
2e267bd6a8
commit
3fe98bacc7
@ -2,6 +2,7 @@
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <json.hpp>
|
||||
#include <net/ndrv.h>
|
||||
#include <netdb.h>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
@ -73,9 +74,9 @@
|
||||
} break; \
|
||||
}
|
||||
|
||||
constexpr const size_t PACKET_SIZE = 262144;
|
||||
constexpr const int CONN_ATTEMPTS = 5;
|
||||
constexpr const int CONN_WAIT = 1000;
|
||||
constexpr const size_t PACKET_SIZE = 1408;
|
||||
constexpr const uint16_t ETHER_TYPE = 32923;
|
||||
constexpr const uint16_t ETHER_TYPE_NTOHS = ntohs(ETHER_TYPE);
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
@ -107,170 +108,122 @@ array ensure_row_contiguous(const array& arr) {
|
||||
}
|
||||
}
|
||||
|
||||
struct address_t {
|
||||
sockaddr_storage addr;
|
||||
socklen_t len;
|
||||
// A 6-byte hardware (MAC) address parsed from colon-separated hex text,
// e.g. "aa:bb:cc:dd:ee:ff".
struct mac_address {
  // The address bytes in the order they appear in the text.
  uint8_t raw[6] = {0};

  // Parse `address` into `raw`.
  //
  // Each ':' advances to the next byte; every other character is treated as
  // one hex digit and shifted into the current byte, most significant nibble
  // first. Characters that are not hex digits contribute a zero nibble, and
  // anything after the sixth byte is ignored. Missing bytes stay zero.
  mac_address(const std::string& address) {
    auto hex_to_int = [](const char c) -> uint8_t {
      if (c >= '0' && c <= '9') {
        return c - '0';
      }
      if (c >= 'a' && c <= 'f') {
        return c - 'a' + 10;
      }
      if (c >= 'A' && c <= 'F') {
        return c - 'A' + 10;
      }
      return 0;
    };

    int idx = 0;
    for (const auto c : address) {
      if (c == ':') {
        idx += 1;
        if (idx >= 6) {
          break;
        }
      } else {
        // Shift the nibble(s) already read up by one hex digit before adding
        // the new one, so "ab" becomes 0xab rather than 0x0a + 0x0b.
        raw[idx] = static_cast<uint8_t>((raw[idx] << 4) | hex_to_int(c));
      }
    }
  }

  // Copy the 6 address bytes into `buf`, which must have room for them.
  void to_buffer(char* buf) const {
    for (int i = 0; i < 6; i++) {
      buf[i] = static_cast<char>(raw[i]);
    }
  }
};
|
||||
|
||||
address_t parse_address(std::string ip, std::string port) {
|
||||
struct addrinfo hints, *res;
|
||||
memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
||||
if (status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Can't parse peer address " << ip << ":" << port;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
address_t result;
|
||||
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
||||
result.len = res->ai_addrlen;
|
||||
freeaddrinfo(res);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<address_t> load_peers() {
|
||||
std::vector<address_t> peers;
|
||||
std::pair<std::string, std::vector<mac_address>> parse_config() {
|
||||
std::vector<mac_address> peers;
|
||||
std::ifstream f;
|
||||
|
||||
if (const char* hostfile_buf = std::getenv("MLX_HOSTFILE")) {
|
||||
f.open(hostfile_buf);
|
||||
} else {
|
||||
return peers;
|
||||
return {"lo0", peers};
|
||||
}
|
||||
|
||||
json hosts = json::parse(f);
|
||||
for (auto& h : hosts) {
|
||||
peers.push_back(std::move(parse_address(
|
||||
h["ip"].template get<std::string>(),
|
||||
h["port"].template get<std::string>())));
|
||||
json config = json::parse(f);
|
||||
for (auto& h : config["peers"]) {
|
||||
peers.emplace_back(h.get<std::string>());
|
||||
}
|
||||
|
||||
return peers;
|
||||
return {config["interface"].get<std::string>(), peers};
|
||||
}
|
||||
|
||||
struct GroupImpl {
|
||||
GroupImpl(std::vector<address_t> peers, int rank, bool global)
|
||||
: rank_(rank), global_(global), pool_(4), sockets_(peers.size(), -1) {
|
||||
if (rank_ > 0 && rank_ >= peers.size()) {
|
||||
GroupImpl(
|
||||
const std::string& interface,
|
||||
std::vector<mac_address> peers,
|
||||
int rank,
|
||||
bool global)
|
||||
: rank_(rank), global_(global), pool_(1), peers_(std::move(peers)) {
|
||||
if (rank_ > 0 && rank_ >= peers_.size()) {
|
||||
throw std::runtime_error(
|
||||
"Rank cannot be larger than the size of the group");
|
||||
}
|
||||
|
||||
int success;
|
||||
|
||||
// If we are expecting anyone to connect to us
|
||||
if (rank_ + 1 < peers.size()) {
|
||||
// Create the socket to wait for connections from the peers
|
||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Make sure we can launch immediately after shutdown by setting the
|
||||
// reuseaddr option so that we don't get address already in use errors
|
||||
int enable = 1;
|
||||
success =
|
||||
setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't enable reuseaddr (rank: " << rank_
|
||||
<< " error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
success =
|
||||
setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't enable reuseport (rank: " << rank_
|
||||
<< " error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Bind it to the port
|
||||
success = bind(sock, peers[rank_].sockaddr(), peers[rank_].len);
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't bind socket (rank: " << rank_ << " error: " << errno
|
||||
<< ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Wait for connections
|
||||
success = listen(sock, 0);
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't listen (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
for (int i = 0; i < peers.size() - rank_ - 1; i++) {
|
||||
int peer_socket = accept(sock, nullptr, nullptr);
|
||||
if (peer_socket < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "Accept failed (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
sockets_[peers.size() - 1 - i] = peer_socket;
|
||||
}
|
||||
|
||||
// Close the listening socket
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
if (peers_.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Connect to the peers with smaller rank
|
||||
for (int i = 0; i < rank_; i++) {
|
||||
sockets_[i] = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sockets_[i] < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
|
||||
if (attempt > 0) {
|
||||
int wait = (1 << (attempt - 1)) * CONN_WAIT;
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
||||
}
|
||||
success = connect(sockets_[i], peers[i].sockaddr(), peers[i].len);
|
||||
if (success == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't connect (rank: " << rank_ << " to: " << i
|
||||
<< " error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
// Make the socket
|
||||
socket_ = socket(PF_NDRV, SOCK_RAW, 0);
|
||||
if (socket_ < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Make the address to bind the socket
|
||||
std::copy(interface.begin(), interface.end(), (char*)sockaddr_.snd_name);
|
||||
sockaddr_.snd_family = PF_NDRV;
|
||||
sockaddr_.snd_len = sizeof(sockaddr_);
|
||||
if (bind(socket_, (sockaddr*)&sockaddr_, sizeof(sockaddr_)) < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't bind socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Tell the kernel to filter and select for ETHER_TYPE
|
||||
ndrv_protocol_desc desc;
|
||||
ndrv_demux_desc demux_desc;
|
||||
desc.version = NDRV_PROTOCOL_DESC_VERS;
|
||||
desc.protocol_family = ETHER_TYPE;
|
||||
desc.demux_count = 1;
|
||||
desc.demux_list = &demux_desc;
|
||||
demux_desc.type = NDRV_DEMUXTYPE_ETHERTYPE;
|
||||
demux_desc.length = sizeof(uint16_t);
|
||||
demux_desc.data.ether_type = ETHER_TYPE_NTOHS;
|
||||
if (setsockopt(
|
||||
socket_, SOL_NDRVPROTO, NDRV_SETDMXSPEC, &desc, sizeof(desc)) < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't set socket option (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
// Close the raw socket, but only for the global group and only when the
// stored descriptor is positive.
~GroupImpl() {
  if (!global_) {
    return;
  }
  if (socket_ <= 0) {
    return;
  }
  close(socket_);
}
|
||||
|
||||
@ -279,32 +232,57 @@ struct GroupImpl {
|
||||
}
|
||||
|
||||
int size() {
|
||||
return std::max(sockets_.size(), 1ul);
|
||||
return std::max(peers_.size(), 1ul);
|
||||
}
|
||||
|
||||
void send_packet(const char* buf, size_t len, int dst) {
|
||||
char packet[1500];
|
||||
peers_[dst].to_buffer(packet);
|
||||
peers_[rank_].to_buffer(packet + sizeof(mac_address));
|
||||
memcpy(packet + 2 * sizeof(mac_address), ÐER_TYPE_NTOHS, sizeof(ETHER_TYPE_NTOHS));
|
||||
constexpr int header_len = 2 * sizeof(mac_address) + 2;
|
||||
memcpy(packet + header_len, buf, len);
|
||||
int r = sendto(
|
||||
socket_,
|
||||
packet,
|
||||
len + header_len,
|
||||
0,
|
||||
(sockaddr*)&sockaddr_,
|
||||
sizeof(sockaddr_));
|
||||
if (r < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Send failed (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
void recv_packet(char* buf, size_t len, int src) {
|
||||
char packet[1500];
|
||||
constexpr int header_len = 2 * sizeof(mac_address) + 2;
|
||||
int r = ::recv(socket_, packet, len + header_len, 0);
|
||||
if (r < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Send failed (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
memcpy(buf, packet + header_len, len);
|
||||
}
|
||||
|
||||
void send(const char* buf, size_t len, int dst) {
|
||||
while (len > 0) {
|
||||
ssize_t r = ::send(sockets_[dst], buf, len, 0);
|
||||
if (r <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Send of " << len << " bytes failed (errno: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
buf += r;
|
||||
len -= r;
|
||||
size_t l = std::min(len, PACKET_SIZE);
|
||||
send_packet(buf, l, dst);
|
||||
buf += l;
|
||||
len -= l;
|
||||
}
|
||||
}
|
||||
|
||||
void recv(char* buf, size_t len, int src) {
|
||||
while (len > 0) {
|
||||
ssize_t r = ::recv(sockets_[src], buf, len, 0);
|
||||
if (r <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
buf += r;
|
||||
len -= r;
|
||||
size_t l = std::min(len, PACKET_SIZE);
|
||||
recv_packet(buf, l, src);
|
||||
buf += l;
|
||||
len -= l;
|
||||
}
|
||||
}
|
||||
|
||||
@ -364,7 +342,9 @@ struct GroupImpl {
|
||||
int rank_;
|
||||
bool global_;
|
||||
ThreadPool pool_;
|
||||
std::vector<int> sockets_;
|
||||
std::vector<mac_address> peers_;
|
||||
sockaddr_ndrv sockaddr_;
|
||||
int socket_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@ -389,7 +369,7 @@ Group init(bool strict /* = false */) {
|
||||
static std::shared_ptr<GroupImpl> global_group = nullptr;
|
||||
|
||||
if (global_group == nullptr) {
|
||||
auto peers = load_peers();
|
||||
auto [iface, peers] = parse_config();
|
||||
int rank = 0;
|
||||
if (const char* rank_buf = std::getenv("MLX_RANK")) {
|
||||
rank = std::atoi(rank_buf);
|
||||
@ -399,7 +379,8 @@ Group init(bool strict /* = false */) {
|
||||
throw std::runtime_error("Can't initialize distributed");
|
||||
}
|
||||
}
|
||||
global_group = std::make_shared<GroupImpl>(std::move(peers), rank, true);
|
||||
global_group =
|
||||
std::make_shared<GroupImpl>(iface, std::move(peers), rank, true);
|
||||
}
|
||||
return Group(global_group);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user