mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-02 07:01:28 +08:00
TCP socket distributed
This commit is contained in:
parent
97a9561e34
commit
a9746587f1
@ -5,14 +5,77 @@
|
||||
#include <netdb.h>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
#include <chrono>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/io/threadpool.h"
|
||||
|
||||
#define SWITCH_TYPE(x, ...) \
|
||||
switch ((x).dtype()) { \
|
||||
case bool_: { \
|
||||
using T = bool; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case int8: { \
|
||||
using T = int8_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case int16: { \
|
||||
using T = int16_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case int32: { \
|
||||
using T = int32_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case int64: { \
|
||||
using T = int64_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case uint8: { \
|
||||
using T = uint8_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case uint16: { \
|
||||
using T = uint16_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case uint32: { \
|
||||
using T = uint32_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case uint64: { \
|
||||
using T = uint64_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case bfloat16: { \
|
||||
using T = bfloat16_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case float16: { \
|
||||
using T = float16_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case float32: { \
|
||||
using T = float; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
case complex64: { \
|
||||
using T = complex64_t; \
|
||||
__VA_ARGS__; \
|
||||
} break; \
|
||||
}
|
||||
|
||||
constexpr const size_t PACKET_SIZE = 262144;
|
||||
constexpr const int CONN_ATTEMPTS = 5;
|
||||
constexpr const int CONN_WAIT = 1000;
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
@ -30,46 +93,8 @@ void sum_inplace(const T* input, T* output, size_t N) {
|
||||
}
|
||||
|
||||
void sum_inplace(const array& input, array& output) {
|
||||
switch (input.dtype()) {
|
||||
case bool_:
|
||||
return sum_inplace(input.data<bool>(), output.data<bool>(), input.size());
|
||||
case int8:
|
||||
return sum_inplace(
|
||||
input.data<int8_t>(), output.data<int8_t>(), input.size());
|
||||
case uint8:
|
||||
return sum_inplace(
|
||||
input.data<uint8_t>(), output.data<uint8_t>(), input.size());
|
||||
case int16:
|
||||
return sum_inplace(
|
||||
input.data<int16_t>(), output.data<int16_t>(), input.size());
|
||||
case uint16:
|
||||
return sum_inplace(
|
||||
input.data<uint16_t>(), output.data<uint16_t>(), input.size());
|
||||
case int32:
|
||||
return sum_inplace(
|
||||
input.data<int32_t>(), output.data<int32_t>(), input.size());
|
||||
case uint32:
|
||||
return sum_inplace(
|
||||
input.data<uint32_t>(), output.data<uint32_t>(), input.size());
|
||||
case int64:
|
||||
return sum_inplace(
|
||||
input.data<int64_t>(), output.data<int64_t>(), input.size());
|
||||
case uint64:
|
||||
return sum_inplace(
|
||||
input.data<uint64_t>(), output.data<uint64_t>(), input.size());
|
||||
case float16:
|
||||
return sum_inplace(
|
||||
input.data<float16_t>(), output.data<float16_t>(), input.size());
|
||||
case bfloat16:
|
||||
return sum_inplace(
|
||||
input.data<bfloat16_t>(), output.data<bfloat16_t>(), input.size());
|
||||
case float32:
|
||||
return sum_inplace(
|
||||
input.data<float>(), output.data<float>(), input.size());
|
||||
case complex64:
|
||||
return sum_inplace(
|
||||
input.data<complex64_t>(), output.data<complex64_t>(), input.size());
|
||||
}
|
||||
SWITCH_TYPE(
|
||||
input, sum_inplace(input.data<T>(), output.data<T>(), input.size()));
|
||||
}
|
||||
|
||||
array ensure_row_contiguous(const array& arr) {
|
||||
@ -95,7 +120,7 @@ address_t parse_address(std::string ip, std::string port) {
|
||||
struct addrinfo hints, *res;
|
||||
memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_DGRAM;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
||||
if (status != 0) {
|
||||
@ -134,30 +159,118 @@ std::vector<address_t> load_peers() {
|
||||
|
||||
struct GroupImpl {
|
||||
GroupImpl(std::vector<address_t> peers, int rank, bool global)
|
||||
: peers_(std::move(peers)), rank_(rank), global_(global) {
|
||||
if (rank_ > 0 && rank_ >= peers_.size()) {
|
||||
: rank_(rank), global_(global), pool_(4), sockets_(peers.size(), -1) {
|
||||
if (rank_ > 0 && rank_ >= peers.size()) {
|
||||
throw std::runtime_error(
|
||||
"Rank cannot be larger than the size of the group");
|
||||
}
|
||||
if (global_ && rank_ < peers_.size()) {
|
||||
socket_fd_ = socket(AF_INET, SOCK_DGRAM, 0);
|
||||
if (socket_fd_ < 0) {
|
||||
|
||||
int success;
|
||||
|
||||
// If we are expecting anyone to connect to us
|
||||
if (rank_ < peers.size() - 1) {
|
||||
// Create the socket to wait for connections from the peers
|
||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
int success =
|
||||
bind(socket_fd_, peers_[rank_].sockaddr(), peers_[rank_].len);
|
||||
|
||||
// Make sure we can launch immediately after shutdown by setting the
|
||||
// reuseaddr option so that we don't get address already in use errors
|
||||
int enable = 1;
|
||||
success =
|
||||
setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't enable reuseaddr (rank: " << rank_
|
||||
<< " error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
success =
|
||||
setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't enable reuseport (rank: " << rank_
|
||||
<< " error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Bind it to the port
|
||||
success = bind(sock, peers[rank_].sockaddr(), peers[rank_].len);
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't bind socket (rank: " << rank_ << " error: " << errno
|
||||
<< ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Wait for connections
|
||||
success = listen(sock, 0);
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't listen (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
for (int i = 0; i < peers.size() - rank_ - 1; i++) {
|
||||
int peer_socket = accept(sock, nullptr, nullptr);
|
||||
if (peer_socket < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "Accept failed (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
sockets_[peers.size() - 1 - i] = peer_socket;
|
||||
}
|
||||
|
||||
// Close the listening socket
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
}
|
||||
|
||||
// Connect to the peers with smaller rank
|
||||
for (int i = 0; i < rank_; i++) {
|
||||
sockets_[i] = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sockets_[i] < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
|
||||
if (attempt > 0) {
|
||||
int wait = (1 << (attempt - 1)) * CONN_WAIT;
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
||||
}
|
||||
success = connect(sockets_[i], peers[i].sockaddr(), peers[i].len);
|
||||
if (success == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Couldn't bind socket (error: " << errno << ")";
|
||||
msg << "Couldn't connect (rank: " << rank_ << " to: " << i
|
||||
<< " error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
~GroupImpl() {
|
||||
if (global_) {
|
||||
close(socket_fd_);
|
||||
for (int sock : sockets_) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -166,43 +279,92 @@ struct GroupImpl {
|
||||
}
|
||||
|
||||
int size() {
|
||||
return std::max(peers_.size(), 1ul);
|
||||
return std::max(sockets_.size(), 1ul);
|
||||
}
|
||||
|
||||
void send(const char* buf, size_t len, int dst) {
|
||||
while (len > 0) {
|
||||
size_t l = std::min(len, 8192ul);
|
||||
ssize_t r = sendto(
|
||||
socket_fd_, buf, l, 0, peers_[dst].sockaddr(), peers_[dst].len);
|
||||
ssize_t r = ::send(sockets_[dst], buf, len, 0);
|
||||
if (r <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Send of " << l << " bytes failed (errno: " << errno << ")";
|
||||
msg << "Send of " << len << " bytes failed (errno: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
len -= l;
|
||||
buf += l;
|
||||
}
|
||||
}
|
||||
|
||||
void recv(char* buf, size_t len, int src) {
|
||||
sockaddr_storage addr;
|
||||
socklen_t addr_len;
|
||||
while (len != 0) {
|
||||
ssize_t r =
|
||||
recvfrom(socket_fd_, buf, len, 0, (struct sockaddr*)&addr, &addr_len);
|
||||
if (r <= 0) {
|
||||
throw std::runtime_error("Recv failed");
|
||||
}
|
||||
buf += r;
|
||||
len -= r;
|
||||
}
|
||||
}
|
||||
|
||||
void recv(char* buf, size_t len, int src) {
|
||||
while (len > 0) {
|
||||
ssize_t r = ::recv(sockets_[src], buf, len, 0);
|
||||
if (r <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
buf += r;
|
||||
len -= r;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void send_recv_sum(char* buf, size_t len, int peer) {
|
||||
char recv_buffer[2 * PACKET_SIZE];
|
||||
char* recv_buffers[2];
|
||||
recv_buffers[0] = recv_buffer;
|
||||
recv_buffers[1] = recv_buffer + PACKET_SIZE;
|
||||
std::future<void> sent, received;
|
||||
size_t n_blocks = (len + PACKET_SIZE - 1) / PACKET_SIZE;
|
||||
|
||||
for (size_t b = 0; b < n_blocks; b++) {
|
||||
if (b > 0) {
|
||||
sent.wait();
|
||||
received.wait();
|
||||
}
|
||||
size_t l = std::min(len - b * PACKET_SIZE, PACKET_SIZE);
|
||||
if (rank_ < peer) {
|
||||
sent = send_async(buf + b * PACKET_SIZE, l, peer);
|
||||
received = recv_async(recv_buffers[b % 2], l, peer);
|
||||
} else {
|
||||
received = recv_async(recv_buffers[b % 2], l, peer);
|
||||
sent = send_async(buf + b * PACKET_SIZE, l, peer);
|
||||
}
|
||||
if (b > 0) {
|
||||
sum_inplace(
|
||||
(const T*)recv_buffers[(b - 1) % 2],
|
||||
(T*)(buf + (b - 1) * PACKET_SIZE),
|
||||
PACKET_SIZE / sizeof(T));
|
||||
}
|
||||
}
|
||||
sent.wait();
|
||||
received.wait();
|
||||
size_t l = std::min(len - (n_blocks - 1) * PACKET_SIZE, PACKET_SIZE);
|
||||
sum_inplace(
|
||||
(const T*)recv_buffers[(n_blocks - 1) % 2],
|
||||
(T*)(buf + (n_blocks - 1) * PACKET_SIZE),
|
||||
l / sizeof(T));
|
||||
}
|
||||
|
||||
void send_recv_sum(array& out, int peer) {
|
||||
SWITCH_TYPE(out, send_recv_sum<T>(out.data<char>(), out.nbytes(), peer));
|
||||
}
|
||||
|
||||
std::future<void> send_async(const char* buf, size_t len, int dst) {
|
||||
return pool_.enqueue(
|
||||
[this, buf, len, dst]() { this->send(buf, len, dst); });
|
||||
}
|
||||
|
||||
std::future<void> recv_async(char* buf, size_t len, int src) {
|
||||
return pool_.enqueue(
|
||||
[this, buf, len, src]() { this->recv(buf, len, src); });
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<address_t> peers_;
|
||||
int rank_;
|
||||
bool global_;
|
||||
int socket_fd_;
|
||||
ThreadPool pool_;
|
||||
std::vector<int> sockets_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@ -251,57 +413,84 @@ Stream communication_stream() {
|
||||
|
||||
void all_sum(Group group_, const array& input_, array& output) {
|
||||
auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
|
||||
if (group->size() != 2) {
|
||||
throw std::runtime_error("Only pairwise communication supported for now");
|
||||
}
|
||||
array input = ensure_row_contiguous(input_);
|
||||
|
||||
// Donation not supported
|
||||
if (input.data<void>() == output.data<void>()) {
|
||||
array temp(
|
||||
allocator::malloc_or_wait(output.nbytes()),
|
||||
output.shape(),
|
||||
output.dtype());
|
||||
if (group->rank() == 0) {
|
||||
group->send(input.data<char>(), input.nbytes(), 1);
|
||||
group->recv(temp.data<char>(), output.nbytes(), 1);
|
||||
sum_inplace(temp, output);
|
||||
} else {
|
||||
group->recv(temp.data<char>(), output.nbytes(), 0);
|
||||
group->send(input.data<char>(), input.nbytes(), 0);
|
||||
sum_inplace(temp, output);
|
||||
}
|
||||
} else {
|
||||
if (group->rank() == 0) {
|
||||
group->send(input.data<char>(), input.nbytes(), 1);
|
||||
group->recv(output.data<char>(), output.nbytes(), 1);
|
||||
sum_inplace(input, output);
|
||||
} else {
|
||||
group->recv(output.data<char>(), output.nbytes(), 0);
|
||||
group->send(input.data<char>(), input.nbytes(), 0);
|
||||
sum_inplace(input, output);
|
||||
}
|
||||
int size = group->size();
|
||||
int rank = group->rank();
|
||||
|
||||
if ((size & (size - 1)) != 0) {
|
||||
throw std::runtime_error("Only powers of 2 are currently supported");
|
||||
}
|
||||
|
||||
// If not inplace all reduce then copy the input to the output first.
|
||||
if (input.data<void>() != output.data<void>()) {
|
||||
std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
|
||||
}
|
||||
|
||||
// Butterfly all reduce
|
||||
for (int distance = 1; distance <= size / 2; distance *= 2) {
|
||||
group->send_recv_sum(output, rank ^ distance);
|
||||
}
|
||||
}
|
||||
|
||||
void all_gather(Group group_, const array& input_, array& output) {
|
||||
auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
|
||||
if (group->size() != 2) {
|
||||
throw std::runtime_error("Only pairwise communication supported for now");
|
||||
}
|
||||
array input = ensure_row_contiguous(input_);
|
||||
if (group->rank() == 0) {
|
||||
group->send(input.data<char>(), input.nbytes(), 1);
|
||||
group->recv(output.data<char>() + input.nbytes(), input.nbytes(), 1);
|
||||
std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
|
||||
} else {
|
||||
group->recv(output.data<char>(), input.nbytes(), 0);
|
||||
group->send(input.data<char>(), input.nbytes(), 0);
|
||||
std::memcpy(
|
||||
output.data<char>() + input.nbytes(),
|
||||
input.data<char>(),
|
||||
input.nbytes());
|
||||
std::future<void> sent;
|
||||
std::future<void> received;
|
||||
|
||||
int rank = group->rank();
|
||||
int size = group->size();
|
||||
|
||||
if ((size & (size - 1)) != 0) {
|
||||
throw std::runtime_error("Only powers of 2 are currently supported");
|
||||
}
|
||||
|
||||
// Butterfly all gather
|
||||
int peer = rank ^ 1;
|
||||
if (peer < rank) {
|
||||
received = group->recv_async(
|
||||
output.data<char>() + peer * input.nbytes(), input.nbytes(), peer);
|
||||
sent = group->send_async(input.data<char>(), input.nbytes(), peer);
|
||||
} else {
|
||||
sent = group->send_async(input.data<char>(), input.nbytes(), peer);
|
||||
received = group->recv_async(
|
||||
output.data<char>() + peer * input.nbytes(), input.nbytes(), peer);
|
||||
}
|
||||
std::memcpy(
|
||||
output.data<char>() + rank * input.nbytes(),
|
||||
input.data<char>(),
|
||||
input.nbytes());
|
||||
|
||||
for (int distance = 2; distance <= size / 2; distance *= 2) {
|
||||
sent.wait();
|
||||
received.wait();
|
||||
int peer = rank ^ distance;
|
||||
int their_offset = peer & ~(distance - 1);
|
||||
int our_offset = rank & ~(distance - 1);
|
||||
|
||||
if (peer < rank) {
|
||||
received = group->recv_async(
|
||||
output.data<char>() + their_offset * input.nbytes(),
|
||||
distance * input.nbytes(),
|
||||
peer);
|
||||
sent = group->send_async(
|
||||
output.data<char>() + our_offset * input.nbytes(),
|
||||
distance * input.nbytes(),
|
||||
peer);
|
||||
} else {
|
||||
sent = group->send_async(
|
||||
output.data<char>() + our_offset * input.nbytes(),
|
||||
distance * input.nbytes(),
|
||||
peer);
|
||||
received = group->recv_async(
|
||||
output.data<char>() + their_offset * input.nbytes(),
|
||||
distance * input.nbytes(),
|
||||
peer);
|
||||
}
|
||||
}
|
||||
sent.wait();
|
||||
received.wait();
|
||||
}
|
||||
|
||||
void send(Group group_, const array& input_, int dst) {
|
||||
|
Loading…
Reference in New Issue
Block a user