mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Raw sockets
This commit is contained in:
parent
2e267bd6a8
commit
3fe98bacc7
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include <arpa/inet.h>
|
#include <arpa/inet.h>
|
||||||
#include <json.hpp>
|
#include <json.hpp>
|
||||||
|
#include <net/ndrv.h>
|
||||||
#include <netdb.h>
|
#include <netdb.h>
|
||||||
#include <sys/socket.h>
|
#include <sys/socket.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
@ -73,9 +74,9 @@
|
|||||||
} break; \
|
} break; \
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr const size_t PACKET_SIZE = 262144;
|
constexpr const size_t PACKET_SIZE = 1408;
|
||||||
constexpr const int CONN_ATTEMPTS = 5;
|
constexpr const uint16_t ETHER_TYPE = 32923;
|
||||||
constexpr const int CONN_WAIT = 1000;
|
constexpr const uint16_t ETHER_TYPE_NTOHS = ntohs(ETHER_TYPE);
|
||||||
|
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
|
|
||||||
@ -107,170 +108,122 @@ array ensure_row_contiguous(const array& arr) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct address_t {
|
struct mac_address {
|
||||||
sockaddr_storage addr;
|
uint8_t raw[6] = {0};
|
||||||
socklen_t len;
|
|
||||||
|
|
||||||
const sockaddr* sockaddr() {
|
mac_address(const std::string& address) {
|
||||||
return (struct sockaddr*)&addr;
|
auto hex_to_int = [](const char c) -> uint8_t {
|
||||||
|
if (c >= 'a' && c <= 'f') {
|
||||||
|
return c - 'a' + 10;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (c >= 'A' && c <= 'F') {
|
||||||
|
return c - 'A' + 10;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (c >= '0' && c <= '9') {
|
||||||
|
return c - '0';
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
int idx = 0;
|
||||||
|
int cnt = 0;
|
||||||
|
for (const auto c : address) {
|
||||||
|
if (c == ':') {
|
||||||
|
idx += 1;
|
||||||
|
cnt = 0;
|
||||||
|
if (idx >= 6) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
raw[idx] <<= 4 * cnt;
|
||||||
|
raw[idx] += hex_to_int(c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void to_buffer(char* buf) {
|
||||||
|
for (int i = 0; i < 6; i++) {
|
||||||
|
buf[i] = ((char*)raw)[i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
address_t parse_address(std::string ip, std::string port) {
|
std::pair<std::string, std::vector<mac_address>> parse_config() {
|
||||||
struct addrinfo hints, *res;
|
std::vector<mac_address> peers;
|
||||||
memset(&hints, 0, sizeof(hints));
|
|
||||||
hints.ai_family = AF_UNSPEC;
|
|
||||||
hints.ai_socktype = SOCK_STREAM;
|
|
||||||
|
|
||||||
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
|
||||||
if (status != 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Can't parse peer address " << ip << ":" << port;
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
address_t result;
|
|
||||||
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
|
||||||
result.len = res->ai_addrlen;
|
|
||||||
freeaddrinfo(res);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<address_t> load_peers() {
|
|
||||||
std::vector<address_t> peers;
|
|
||||||
std::ifstream f;
|
std::ifstream f;
|
||||||
|
|
||||||
if (const char* hostfile_buf = std::getenv("MLX_HOSTFILE")) {
|
if (const char* hostfile_buf = std::getenv("MLX_HOSTFILE")) {
|
||||||
f.open(hostfile_buf);
|
f.open(hostfile_buf);
|
||||||
} else {
|
} else {
|
||||||
return peers;
|
return {"lo0", peers};
|
||||||
}
|
}
|
||||||
|
|
||||||
json hosts = json::parse(f);
|
json config = json::parse(f);
|
||||||
for (auto& h : hosts) {
|
for (auto& h : config["peers"]) {
|
||||||
peers.push_back(std::move(parse_address(
|
peers.emplace_back(h.get<std::string>());
|
||||||
h["ip"].template get<std::string>(),
|
|
||||||
h["port"].template get<std::string>())));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return peers;
|
return {config["interface"].get<std::string>(), peers};
|
||||||
}
|
}
|
||||||
|
|
||||||
struct GroupImpl {
|
struct GroupImpl {
|
||||||
GroupImpl(std::vector<address_t> peers, int rank, bool global)
|
GroupImpl(
|
||||||
: rank_(rank), global_(global), pool_(4), sockets_(peers.size(), -1) {
|
const std::string& interface,
|
||||||
if (rank_ > 0 && rank_ >= peers.size()) {
|
std::vector<mac_address> peers,
|
||||||
|
int rank,
|
||||||
|
bool global)
|
||||||
|
: rank_(rank), global_(global), pool_(1), peers_(std::move(peers)) {
|
||||||
|
if (rank_ > 0 && rank_ >= peers_.size()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"Rank cannot be larger than the size of the group");
|
"Rank cannot be larger than the size of the group");
|
||||||
}
|
}
|
||||||
|
|
||||||
int success;
|
if (peers_.size() == 0) {
|
||||||
|
return;
|
||||||
// If we are expecting anyone to connect to us
|
|
||||||
if (rank_ + 1 < peers.size()) {
|
|
||||||
// Create the socket to wait for connections from the peers
|
|
||||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
|
||||||
if (sock < 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Couldn't create socket (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure we can launch immediately after shutdown by setting the
|
|
||||||
// reuseaddr option so that we don't get address already in use errors
|
|
||||||
int enable = 1;
|
|
||||||
success =
|
|
||||||
setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Couldn't enable reuseaddr (rank: " << rank_
|
|
||||||
<< " error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
success =
|
|
||||||
setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Couldn't enable reuseport (rank: " << rank_
|
|
||||||
<< " error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bind it to the port
|
|
||||||
success = bind(sock, peers[rank_].sockaddr(), peers[rank_].len);
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Couldn't bind socket (rank: " << rank_ << " error: " << errno
|
|
||||||
<< ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for connections
|
|
||||||
success = listen(sock, 0);
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Couldn't listen (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
for (int i = 0; i < peers.size() - rank_ - 1; i++) {
|
|
||||||
int peer_socket = accept(sock, nullptr, nullptr);
|
|
||||||
if (peer_socket < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Accept failed (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
sockets_[peers.size() - 1 - i] = peer_socket;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the listening socket
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect to the peers with smaller rank
|
// Make the socket
|
||||||
for (int i = 0; i < rank_; i++) {
|
socket_ = socket(PF_NDRV, SOCK_RAW, 0);
|
||||||
sockets_[i] = socket(AF_INET, SOCK_STREAM, 0);
|
if (socket_ < 0) {
|
||||||
if (sockets_[i] < 0) {
|
std::ostringstream msg;
|
||||||
std::ostringstream msg;
|
msg << "Couldn't create socket (error: " << errno << ")";
|
||||||
msg << "Couldn't create socket (error: " << errno << ")";
|
throw std::runtime_error(msg.str());
|
||||||
throw std::runtime_error(msg.str());
|
}
|
||||||
}
|
|
||||||
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
|
// Make the address to bind the socket
|
||||||
if (attempt > 0) {
|
std::copy(interface.begin(), interface.end(), (char*)sockaddr_.snd_name);
|
||||||
int wait = (1 << (attempt - 1)) * CONN_WAIT;
|
sockaddr_.snd_family = PF_NDRV;
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
sockaddr_.snd_len = sizeof(sockaddr_);
|
||||||
}
|
if (bind(socket_, (sockaddr*)&sockaddr_, sizeof(sockaddr_)) < 0) {
|
||||||
success = connect(sockets_[i], peers[i].sockaddr(), peers[i].len);
|
std::ostringstream msg;
|
||||||
if (success == 0) {
|
msg << "Couldn't bind socket (error: " << errno << ")";
|
||||||
break;
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if (success < 0) {
|
// Tell the kernel to filter and select for ETHER_TYPE
|
||||||
std::ostringstream msg;
|
ndrv_protocol_desc desc;
|
||||||
msg << "Couldn't connect (rank: " << rank_ << " to: " << i
|
ndrv_demux_desc demux_desc;
|
||||||
<< " error: " << errno << ")";
|
desc.version = NDRV_PROTOCOL_DESC_VERS;
|
||||||
throw std::runtime_error(msg.str());
|
desc.protocol_family = ETHER_TYPE;
|
||||||
}
|
desc.demux_count = 1;
|
||||||
|
desc.demux_list = &demux_desc;
|
||||||
|
demux_desc.type = NDRV_DEMUXTYPE_ETHERTYPE;
|
||||||
|
demux_desc.length = sizeof(uint16_t);
|
||||||
|
demux_desc.data.ether_type = ETHER_TYPE_NTOHS;
|
||||||
|
if (setsockopt(
|
||||||
|
socket_, SOL_NDRVPROTO, NDRV_SETDMXSPEC, &desc, sizeof(desc)) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Couldn't set socket option (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
~GroupImpl() {
|
~GroupImpl() {
|
||||||
if (global_) {
|
if (global_ && socket_ > 0) {
|
||||||
for (int sock : sockets_) {
|
close(socket_);
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -279,32 +232,57 @@ struct GroupImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int size() {
|
int size() {
|
||||||
return std::max(sockets_.size(), 1ul);
|
return std::max(peers_.size(), 1ul);
|
||||||
|
}
|
||||||
|
|
||||||
|
void send_packet(const char* buf, size_t len, int dst) {
|
||||||
|
char packet[1500];
|
||||||
|
peers_[dst].to_buffer(packet);
|
||||||
|
peers_[rank_].to_buffer(packet + sizeof(mac_address));
|
||||||
|
memcpy(packet + 2 * sizeof(mac_address), ÐER_TYPE_NTOHS, sizeof(ETHER_TYPE_NTOHS));
|
||||||
|
constexpr int header_len = 2 * sizeof(mac_address) + 2;
|
||||||
|
memcpy(packet + header_len, buf, len);
|
||||||
|
int r = sendto(
|
||||||
|
socket_,
|
||||||
|
packet,
|
||||||
|
len + header_len,
|
||||||
|
0,
|
||||||
|
(sockaddr*)&sockaddr_,
|
||||||
|
sizeof(sockaddr_));
|
||||||
|
if (r < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Send failed (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void recv_packet(char* buf, size_t len, int src) {
|
||||||
|
char packet[1500];
|
||||||
|
constexpr int header_len = 2 * sizeof(mac_address) + 2;
|
||||||
|
int r = ::recv(socket_, packet, len + header_len, 0);
|
||||||
|
if (r < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Send failed (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
memcpy(buf, packet + header_len, len);
|
||||||
}
|
}
|
||||||
|
|
||||||
void send(const char* buf, size_t len, int dst) {
|
void send(const char* buf, size_t len, int dst) {
|
||||||
while (len > 0) {
|
while (len > 0) {
|
||||||
ssize_t r = ::send(sockets_[dst], buf, len, 0);
|
size_t l = std::min(len, PACKET_SIZE);
|
||||||
if (r <= 0) {
|
send_packet(buf, l, dst);
|
||||||
std::ostringstream msg;
|
buf += l;
|
||||||
msg << "Send of " << len << " bytes failed (errno: " << errno << ")";
|
len -= l;
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
buf += r;
|
|
||||||
len -= r;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void recv(char* buf, size_t len, int src) {
|
void recv(char* buf, size_t len, int src) {
|
||||||
while (len > 0) {
|
while (len > 0) {
|
||||||
ssize_t r = ::recv(sockets_[src], buf, len, 0);
|
size_t l = std::min(len, PACKET_SIZE);
|
||||||
if (r <= 0) {
|
recv_packet(buf, l, src);
|
||||||
std::ostringstream msg;
|
buf += l;
|
||||||
msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
|
len -= l;
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
buf += r;
|
|
||||||
len -= r;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -364,7 +342,9 @@ struct GroupImpl {
|
|||||||
int rank_;
|
int rank_;
|
||||||
bool global_;
|
bool global_;
|
||||||
ThreadPool pool_;
|
ThreadPool pool_;
|
||||||
std::vector<int> sockets_;
|
std::vector<mac_address> peers_;
|
||||||
|
sockaddr_ndrv sockaddr_;
|
||||||
|
int socket_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -389,7 +369,7 @@ Group init(bool strict /* = false */) {
|
|||||||
static std::shared_ptr<GroupImpl> global_group = nullptr;
|
static std::shared_ptr<GroupImpl> global_group = nullptr;
|
||||||
|
|
||||||
if (global_group == nullptr) {
|
if (global_group == nullptr) {
|
||||||
auto peers = load_peers();
|
auto [iface, peers] = parse_config();
|
||||||
int rank = 0;
|
int rank = 0;
|
||||||
if (const char* rank_buf = std::getenv("MLX_RANK")) {
|
if (const char* rank_buf = std::getenv("MLX_RANK")) {
|
||||||
rank = std::atoi(rank_buf);
|
rank = std::atoi(rank_buf);
|
||||||
@ -399,7 +379,8 @@ Group init(bool strict /* = false */) {
|
|||||||
throw std::runtime_error("Can't initialize distributed");
|
throw std::runtime_error("Can't initialize distributed");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
global_group = std::make_shared<GroupImpl>(std::move(peers), rank, true);
|
global_group =
|
||||||
|
std::make_shared<GroupImpl>(iface, std::move(peers), rank, true);
|
||||||
}
|
}
|
||||||
return Group(global_group);
|
return Group(global_group);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user