Raw sockets

This commit is contained in:
Angelos Katharopoulos 2024-09-08 23:21:02 -07:00
parent 2e267bd6a8
commit 3fe98bacc7

View File

@ -2,6 +2,7 @@
#include <arpa/inet.h>
#include <json.hpp>
#include <net/ndrv.h>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
@ -73,9 +74,9 @@
} break; \
}
constexpr const size_t PACKET_SIZE = 262144;
constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;
// Largest payload, in bytes, that send()/recv() pass to a single
// send_packet()/recv_packet() call. Together with the 14 byte header built in
// send_packet() it fits in the 1500 byte frame buffers used there.
constexpr const size_t PACKET_SIZE = 1408;
// EtherType value (32923 == 0x809B) registered with the kernel as the
// protocol family of the raw socket.
constexpr const uint16_t ETHER_TYPE = 32923;
// ETHER_TYPE passed through ntohs(). This is the value written into the
// outgoing frame header and given to the kernel demux filter.
constexpr const uint16_t ETHER_TYPE_NTOHS = ntohs(ETHER_TYPE);
using json = nlohmann::json;
@ -107,170 +108,122 @@ array ensure_row_contiguous(const array& arr) {
}
}
struct address_t {
sockaddr_storage addr;
socklen_t len;
struct mac_address {
uint8_t raw[6] = {0};
const sockaddr* sockaddr() {
return (struct sockaddr*)&addr;
// Parse a textual MAC address made of colon separated hexadecimal octets
// (e.g. "aa:bb:cc:dd:ee:ff", upper or lower case) into raw.
//
// The input is not validated:
//   - a character that is not a hex digit counts as the digit 0,
//   - only the low 8 bits of an octet with more than two digits are kept,
//   - parsing stops at the sixth ':' and missing octets stay 0.
mac_address(const std::string& address) {
  auto hex_to_int = [](const char c) -> uint8_t {
    if (c >= 'a' && c <= 'f') {
      return c - 'a' + 10;
    }
    if (c >= 'A' && c <= 'F') {
      return c - 'A' + 10;
    }
    if (c >= '0' && c <= '9') {
      return c - '0';
    }
    return 0;
  };

  int idx = 0;
  for (const auto c : address) {
    if (c == ':') {
      // Move on to the next octet; ignore anything after the sixth one.
      idx += 1;
      if (idx >= 6) {
        break;
      }
    } else {
      // Shift the digits seen so far up by one nibble and append the new
      // one, so that "ab" becomes 0xab rather than 0xa + 0xb.
      raw[idx] = static_cast<uint8_t>((raw[idx] << 4) | hex_to_int(c));
    }
  }
}
// Write the 6 address bytes, in order, to the start of buf.
// buf must have room for at least 6 bytes.
void to_buffer(char* buf) {
  memcpy(buf, raw, sizeof(raw));
}
};
// Resolve ip and port with getaddrinfo() and return the first result as an
// address_t. Throws std::runtime_error if the lookup fails.
address_t parse_address(std::string ip, std::string port) {
  // Ask for stream sockets of any address family.
  struct addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;

  struct addrinfo* info = nullptr;
  if (getaddrinfo(ip.c_str(), port.c_str(), &hints, &info) != 0) {
    std::ostringstream msg;
    msg << "Can't parse peer address " << ip << ":" << port;
    throw std::runtime_error(msg.str());
  }

  // Copy the first entry out of the list before releasing it.
  address_t parsed;
  parsed.len = info->ai_addrlen;
  memcpy(&parsed.addr, info->ai_addr, info->ai_addrlen);
  freeaddrinfo(info);

  return parsed;
}
std::vector<address_t> load_peers() {
std::vector<address_t> peers;
std::pair<std::string, std::vector<mac_address>> parse_config() {
std::vector<mac_address> peers;
std::ifstream f;
if (const char* hostfile_buf = std::getenv("MLX_HOSTFILE")) {
f.open(hostfile_buf);
} else {
return peers;
return {"lo0", peers};
}
json hosts = json::parse(f);
for (auto& h : hosts) {
peers.push_back(std::move(parse_address(
h["ip"].template get<std::string>(),
h["port"].template get<std::string>())));
json config = json::parse(f);
for (auto& h : config["peers"]) {
peers.emplace_back(h.get<std::string>());
}
return peers;
return {config["interface"].get<std::string>(), peers};
}
struct GroupImpl {
GroupImpl(std::vector<address_t> peers, int rank, bool global)
: rank_(rank), global_(global), pool_(4), sockets_(peers.size(), -1) {
if (rank_ > 0 && rank_ >= peers.size()) {
GroupImpl(
const std::string& interface,
std::vector<mac_address> peers,
int rank,
bool global)
: rank_(rank), global_(global), pool_(1), peers_(std::move(peers)) {
if (rank_ > 0 && rank_ >= peers_.size()) {
throw std::runtime_error(
"Rank cannot be larger than the size of the group");
}
int success;
// If we are expecting anyone to connect to us
if (rank_ + 1 < peers.size()) {
// Create the socket to wait for connections from the peers
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Make sure we can launch immediately after shutdown by setting the
// reuseaddr option so that we don't get address already in use errors
int enable = 1;
success =
setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't enable reuseaddr (rank: " << rank_
<< " error: " << errno << ")";
throw std::runtime_error(msg.str());
}
success =
setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't enable reuseport (rank: " << rank_
<< " error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Bind it to the port
success = bind(sock, peers[rank_].sockaddr(), peers[rank_].len);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't bind socket (rank: " << rank_ << " error: " << errno
<< ")";
throw std::runtime_error(msg.str());
}
// Wait for connections
success = listen(sock, 0);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't listen (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
for (int i = 0; i < peers.size() - rank_ - 1; i++) {
int peer_socket = accept(sock, nullptr, nullptr);
if (peer_socket < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Accept failed (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
sockets_[peers.size() - 1 - i] = peer_socket;
}
// Close the listening socket
shutdown(sock, 2);
close(sock);
if (peers_.size() == 0) {
return;
}
// Connect to the peers with smaller rank
for (int i = 0; i < rank_; i++) {
sockets_[i] = socket(AF_INET, SOCK_STREAM, 0);
if (sockets_[i] < 0) {
std::ostringstream msg;
msg << "Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
if (attempt > 0) {
int wait = (1 << (attempt - 1)) * CONN_WAIT;
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
}
success = connect(sockets_[i], peers[i].sockaddr(), peers[i].len);
if (success == 0) {
break;
}
}
if (success < 0) {
std::ostringstream msg;
msg << "Couldn't connect (rank: " << rank_ << " to: " << i
<< " error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Make the socket
socket_ = socket(PF_NDRV, SOCK_RAW, 0);
if (socket_ < 0) {
std::ostringstream msg;
msg << "Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Make the address to bind the socket
std::copy(interface.begin(), interface.end(), (char*)sockaddr_.snd_name);
sockaddr_.snd_family = PF_NDRV;
sockaddr_.snd_len = sizeof(sockaddr_);
if (bind(socket_, (sockaddr*)&sockaddr_, sizeof(sockaddr_)) < 0) {
std::ostringstream msg;
msg << "Couldn't bind socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Tell the kernel to filter and select for ETHER_TYPE
ndrv_protocol_desc desc;
ndrv_demux_desc demux_desc;
desc.version = NDRV_PROTOCOL_DESC_VERS;
desc.protocol_family = ETHER_TYPE;
desc.demux_count = 1;
desc.demux_list = &demux_desc;
demux_desc.type = NDRV_DEMUXTYPE_ETHERTYPE;
demux_desc.length = sizeof(uint16_t);
demux_desc.data.ether_type = ETHER_TYPE_NTOHS;
if (setsockopt(
socket_, SOL_NDRVPROTO, NDRV_SETDMXSPEC, &desc, sizeof(desc)) < 0) {
std::ostringstream msg;
msg << "Couldn't set socket option (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
}
~GroupImpl() {
if (global_) {
for (int sock : sockets_) {
shutdown(sock, 2);
close(sock);
}
if (global_ && socket_ > 0) {
close(socket_);
}
}
@ -279,32 +232,57 @@ struct GroupImpl {
}
int size() {
return std::max(sockets_.size(), 1ul);
return std::max(peers_.size(), 1ul);
}
// Send len bytes from buf to the peer with rank dst as a single frame.
//
// The frame is laid out as: destination address (peers_[dst]), source address
// (peers_[rank_]), ETHER_TYPE_NTOHS, then the payload.
//
// Throws std::runtime_error if the payload does not fit in the frame buffer
// or if sendto() reports an error.
void send_packet(const char* buf, size_t len, int dst) {
  // The header arithmetic below relies on an address being exactly 6 bytes.
  static_assert(
      sizeof(mac_address) == 6, "mac_address must be exactly 6 bytes");

  char packet[1500];
  constexpr size_t header_len =
      2 * sizeof(mac_address) + sizeof(ETHER_TYPE_NTOHS);

  // Refuse payloads that would overflow the stack buffer.
  if (len > sizeof(packet) - header_len) {
    std::ostringstream msg;
    msg << "Send failed (payload of " << len
        << " bytes does not fit in a packet)";
    throw std::runtime_error(msg.str());
  }

  // Header
  peers_[dst].to_buffer(packet);
  peers_[rank_].to_buffer(packet + sizeof(mac_address));
  memcpy(
      packet + 2 * sizeof(mac_address),
      &ETHER_TYPE_NTOHS,
      sizeof(ETHER_TYPE_NTOHS));

  // Payload
  memcpy(packet + header_len, buf, len);

  ssize_t r = sendto(
      socket_,
      packet,
      len + header_len,
      0,
      (sockaddr*)&sockaddr_,
      sizeof(sockaddr_));
  if (r < 0) {
    std::ostringstream msg;
    msg << "Send failed (error: " << errno << ")";
    throw std::runtime_error(msg.str());
  }
}
// Receive one frame from the raw socket and copy its len byte payload (the
// bytes after the two addresses and the 2 byte type field) into buf.
//
// src is currently not used: the source address of the received frame is not
// compared against the expected peer.
//
// Throws std::runtime_error if len does not fit in the frame buffer, if
// recv() reports an error, or if fewer than len payload bytes arrived.
void recv_packet(char* buf, size_t len, int src) {
  char packet[1500];
  constexpr size_t header_len = 2 * sizeof(mac_address) + 2;

  // Refuse requests that would overflow the stack buffer.
  if (len > sizeof(packet) - header_len) {
    std::ostringstream msg;
    msg << "Recv failed (payload of " << len
        << " bytes does not fit in a packet)";
    throw std::runtime_error(msg.str());
  }

  ssize_t r = ::recv(socket_, packet, len + header_len, 0);
  if (r < 0) {
    std::ostringstream msg;
    msg << "Recv failed (error: " << errno << ")";
    throw std::runtime_error(msg.str());
  }

  // A shorter frame would leave part of packet uninitialised; do not hand
  // those bytes to the caller.
  if (static_cast<size_t>(r) < len + header_len) {
    std::ostringstream msg;
    msg << "Recv failed (expected " << len + header_len << " bytes, got " << r
        << ")";
    throw std::runtime_error(msg.str());
  }

  memcpy(buf, packet + header_len, len);
}
void send(const char* buf, size_t len, int dst) {
while (len > 0) {
ssize_t r = ::send(sockets_[dst], buf, len, 0);
if (r <= 0) {
std::ostringstream msg;
msg << "Send of " << len << " bytes failed (errno: " << errno << ")";
throw std::runtime_error(msg.str());
}
buf += r;
len -= r;
size_t l = std::min(len, PACKET_SIZE);
send_packet(buf, l, dst);
buf += l;
len -= l;
}
}
void recv(char* buf, size_t len, int src) {
while (len > 0) {
ssize_t r = ::recv(sockets_[src], buf, len, 0);
if (r <= 0) {
std::ostringstream msg;
msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
throw std::runtime_error(msg.str());
}
buf += r;
len -= r;
size_t l = std::min(len, PACKET_SIZE);
recv_packet(buf, l, src);
buf += l;
len -= l;
}
}
@ -364,7 +342,9 @@ struct GroupImpl {
int rank_;
bool global_;
ThreadPool pool_;
std::vector<int> sockets_;
// Addresses of all the members of the group, indexed by rank.
std::vector<mac_address> peers_;
// Address the raw socket is bound to and sent through. Zero-initialised
// because the constructor copies the interface name into snd_name without
// writing a terminator.
sockaddr_ndrv sockaddr_{};
// Raw socket descriptor. Starts at -1 because the constructor returns before
// opening a socket when there are no peers, and the destructor only closes
// the descriptor when it is positive.
int socket_ = -1;
};
} // namespace
@ -389,7 +369,7 @@ Group init(bool strict /* = false */) {
static std::shared_ptr<GroupImpl> global_group = nullptr;
if (global_group == nullptr) {
auto peers = load_peers();
auto [iface, peers] = parse_config();
int rank = 0;
if (const char* rank_buf = std::getenv("MLX_RANK")) {
rank = std::atoi(rank_buf);
@ -399,7 +379,8 @@ Group init(bool strict /* = false */) {
throw std::runtime_error("Can't initialize distributed");
}
}
global_group = std::make_shared<GroupImpl>(std::move(peers), rank, true);
global_group =
std::make_shared<GroupImpl>(iface, std::move(peers), rank, true);
}
return Group(global_group);
}