diff --git a/mlx/distributed/sockets/sockets.cpp b/mlx/distributed/sockets/sockets.cpp index 753825f34..806893e2e 100644 --- a/mlx/distributed/sockets/sockets.cpp +++ b/mlx/distributed/sockets/sockets.cpp @@ -5,14 +5,77 @@ #include #include #include +#include #include #include #include #include +#include #include "mlx/backend/common/copy.h" #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" +#include "mlx/io/threadpool.h" + +#define SWITCH_TYPE(x, ...) \ + switch ((x).dtype()) { \ + case bool_: { \ + using T = bool; \ + __VA_ARGS__; \ + } break; \ + case int8: { \ + using T = int8_t; \ + __VA_ARGS__; \ + } break; \ + case int16: { \ + using T = int16_t; \ + __VA_ARGS__; \ + } break; \ + case int32: { \ + using T = int32_t; \ + __VA_ARGS__; \ + } break; \ + case int64: { \ + using T = int64_t; \ + __VA_ARGS__; \ + } break; \ + case uint8: { \ + using T = uint8_t; \ + __VA_ARGS__; \ + } break; \ + case uint16: { \ + using T = uint16_t; \ + __VA_ARGS__; \ + } break; \ + case uint32: { \ + using T = uint32_t; \ + __VA_ARGS__; \ + } break; \ + case uint64: { \ + using T = uint64_t; \ + __VA_ARGS__; \ + } break; \ + case bfloat16: { \ + using T = bfloat16_t; \ + __VA_ARGS__; \ + } break; \ + case float16: { \ + using T = float16_t; \ + __VA_ARGS__; \ + } break; \ + case float32: { \ + using T = float; \ + __VA_ARGS__; \ + } break; \ + case complex64: { \ + using T = complex64_t; \ + __VA_ARGS__; \ + } break; \ + } + +constexpr const size_t PACKET_SIZE = 262144; +constexpr const int CONN_ATTEMPTS = 5; +constexpr const int CONN_WAIT = 1000; using json = nlohmann::json; @@ -30,46 +93,8 @@ void sum_inplace(const T* input, T* output, size_t N) { } void sum_inplace(const array& input, array& output) { - switch (input.dtype()) { - case bool_: - return sum_inplace(input.data(), output.data(), input.size()); - case int8: - return sum_inplace( - input.data(), output.data(), input.size()); - case uint8: - return sum_inplace( - input.data(), output.data(), input.size()); - case int16: - return sum_inplace( - input.data(), output.data(), input.size()); - case uint16: - return sum_inplace( - input.data(), output.data(), input.size()); - case int32: - return sum_inplace( - input.data(), output.data(), input.size()); - case uint32: - return sum_inplace( - input.data(), output.data(), input.size()); - case int64: - return sum_inplace( - input.data(), output.data(), input.size()); - case uint64: - return sum_inplace( - input.data(), output.data(), input.size()); - case float16: - return sum_inplace( - input.data(), output.data(), input.size()); - case bfloat16: - return sum_inplace( - input.data(), output.data(), input.size()); - case float32: - return sum_inplace( - input.data(), output.data(), input.size()); - case complex64: - return sum_inplace( - input.data(), output.data(), input.size()); - } + SWITCH_TYPE( + input, sum_inplace(input.data(), output.data(), input.size())); } array ensure_row_contiguous(const array& arr) { @@ -95,7 +120,7 @@ address_t parse_address(std::string ip, std::string port) { struct addrinfo hints, *res; memset(&hints, 0, sizeof(hints)); hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_DGRAM; + hints.ai_socktype = SOCK_STREAM; int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res); if (status != 0) { @@ -134,30 +159,118 @@ std::vector load_peers() { struct GroupImpl { GroupImpl(std::vector peers, int rank, bool global) - : peers_(std::move(peers)), rank_(rank), global_(global) { - if (rank_ > 0 && rank_ >= peers_.size()) { + : rank_(rank), global_(global), pool_(4), sockets_(peers.size(), -1) { + if (rank_ > 0 && rank_ >= peers.size()) { throw std::runtime_error( "Rank cannot be larger than the size of the group"); } - if (global_ && rank_ < peers_.size()) { - socket_fd_ = socket(AF_INET, SOCK_DGRAM, 0); - if (socket_fd_ < 0) { + + int success; + + // If we are expecting anyone to connect to us + if (rank_ < peers.size() - 1) { + // Create the socket to wait for connections from the peers + int sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock < 0) { std::ostringstream msg; msg << "Couldn't create socket (error: " << errno << ")"; throw std::runtime_error(msg.str()); } - int success = - bind(socket_fd_, peers_[rank_].sockaddr(), peers_[rank_].len); + + // Make sure we can launch immediately after shutdown by setting the + // reuseaddr option so that we don't get address already in use errors + int enable = 1; + success = + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); + if (success < 0) { + shutdown(sock, 2); + close(sock); + std::ostringstream msg; + msg << "Couldn't enable reuseaddr (rank: " << rank_ + << " error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + success = + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int)); + if (success < 0) { + shutdown(sock, 2); + close(sock); + std::ostringstream msg; + msg << "Couldn't enable reuseport (rank: " << rank_ + << " error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + // Bind it to the port + success = bind(sock, peers[rank_].sockaddr(), peers[rank_].len); + if (success < 0) { + shutdown(sock, 2); + close(sock); + std::ostringstream msg; + msg << "Couldn't bind socket (rank: " << rank_ << " error: " << errno + << ")"; + throw std::runtime_error(msg.str()); + } + + // Wait for connections + success = listen(sock, 0); + if (success < 0) { + shutdown(sock, 2); + close(sock); + std::ostringstream msg; + msg << "Couldn't listen (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + for (int i = 0; i < peers.size() - rank_ - 1; i++) { + int peer_socket = accept(sock, nullptr, nullptr); + if (peer_socket < 0) { + shutdown(sock, 2); + close(sock); + std::ostringstream msg; + msg << "Accept failed (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + sockets_[peers.size() - 1 - i] = peer_socket; + } + + // Close the listening socket + shutdown(sock, 2); + close(sock); + } + + // Connect to the peers with smaller rank + for (int i = 0; i < rank_; i++) { + sockets_[i] = socket(AF_INET, SOCK_STREAM, 0); + if (sockets_[i] < 0) { + std::ostringstream msg; + msg << "Couldn't create socket (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) { + if (attempt > 0) { + int wait = (1 << (attempt - 1)) * CONN_WAIT; + std::this_thread::sleep_for(std::chrono::milliseconds(wait)); + } + success = connect(sockets_[i], peers[i].sockaddr(), peers[i].len); + if (success == 0) { + break; + } + } if (success < 0) { std::ostringstream msg; - msg << "Couldn't bind socket (error: " << errno << ")"; + msg << "Couldn't connect (rank: " << rank_ << " to: " << i + << " error: " << errno << ")"; throw std::runtime_error(msg.str()); } } } + ~GroupImpl() { if (global_) { - close(socket_fd_); + for (int sock : sockets_) { + shutdown(sock, 2); + close(sock); + } } } @@ -166,43 +279,92 @@ struct GroupImpl { } int size() { - return std::max(peers_.size(), 1ul); + return std::max(sockets_.size(), 1ul); } void send(const char* buf, size_t len, int dst) { while (len > 0) { - size_t l = std::min(len, 8192ul); - ssize_t r = sendto( - socket_fd_, buf, l, 0, peers_[dst].sockaddr(), peers_[dst].len); + ssize_t r = ::send(sockets_[dst], buf, len, 0); if (r <= 0) { std::ostringstream msg; - msg << "Send of " << l << " bytes failed (errno: " << errno << ")"; + msg << "Send of " << len << " bytes failed (errno: " << errno << ")"; throw std::runtime_error(msg.str()); } - len -= l; - buf += l; - } - } - - void recv(char* buf, size_t len, int src) { - sockaddr_storage addr; - socklen_t addr_len; - while (len != 0) { - ssize_t r = - recvfrom(socket_fd_, buf, len, 0, (struct sockaddr*)&addr, &addr_len); - if (r <= 0) { - throw std::runtime_error("Recv failed"); - } buf += r; len -= r; } } + void recv(char* buf, size_t len, int src) { + while (len > 0) { + ssize_t r = ::recv(sockets_[src], buf, len, 0); + if (r <= 0) { + std::ostringstream msg; + msg << "Recv of " << len << " bytes failed (errno: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + buf += r; + len -= r; + } + } + + template + void send_recv_sum(char* buf, size_t len, int peer) { + char recv_buffer[2 * PACKET_SIZE]; + char* recv_buffers[2]; + recv_buffers[0] = recv_buffer; + recv_buffers[1] = recv_buffer + PACKET_SIZE; + std::future sent, received; + size_t n_blocks = (len + PACKET_SIZE - 1) / PACKET_SIZE; + + for (size_t b = 0; b < n_blocks; b++) { + if (b > 0) { + sent.wait(); + received.wait(); + } + size_t l = std::min(len - b * PACKET_SIZE, PACKET_SIZE); + if (rank_ < peer) { + sent = send_async(buf + b * PACKET_SIZE, l, peer); + received = recv_async(recv_buffers[b % 2], l, peer); + } else { + received = recv_async(recv_buffers[b % 2], l, peer); + sent = send_async(buf + b * PACKET_SIZE, l, peer); + } + if (b > 0) { + sum_inplace( + (const T*)recv_buffers[(b - 1) % 2], + (T*)(buf + (b - 1) * PACKET_SIZE), + PACKET_SIZE / sizeof(T)); + } + } + sent.wait(); + received.wait(); + size_t l = std::min(len - (n_blocks - 1) * PACKET_SIZE, PACKET_SIZE); + sum_inplace( + (const T*)recv_buffers[(n_blocks - 1) % 2], + (T*)(buf + (n_blocks - 1) * PACKET_SIZE), + l / sizeof(T)); + } + + void send_recv_sum(array& out, int peer) { + SWITCH_TYPE(out, send_recv_sum(out.data(), out.nbytes(), peer)); + } + + std::future send_async(const char* buf, size_t len, int dst) { + return pool_.enqueue( + [this, buf, len, dst]() { this->send(buf, len, dst); }); + } + + std::future recv_async(char* buf, size_t len, int src) { + return pool_.enqueue( + [this, buf, len, src]() { this->recv(buf, len, src); }); + } + private: - std::vector peers_; int rank_; bool global_; - int socket_fd_; + ThreadPool pool_; + std::vector sockets_; }; } // namespace @@ -251,57 +413,84 @@ Stream communication_stream() { void all_sum(Group group_, const array& input_, array& output) { auto group = std::static_pointer_cast(group_.raw_group()); - if (group->size() != 2) { - throw std::runtime_error("Only pairwise communication supported for now"); - } array input = ensure_row_contiguous(input_); - // Donation not supported - if (input.data() == output.data()) { - array temp( - allocator::malloc_or_wait(output.nbytes()), - output.shape(), - output.dtype()); - if (group->rank() == 0) { - group->send(input.data(), input.nbytes(), 1); - group->recv(temp.data(), output.nbytes(), 1); - sum_inplace(temp, output); - } else { - group->recv(temp.data(), output.nbytes(), 0); - group->send(input.data(), input.nbytes(), 0); - sum_inplace(temp, output); - } - } else { - if (group->rank() == 0) { - group->send(input.data(), input.nbytes(), 1); - group->recv(output.data(), output.nbytes(), 1); - sum_inplace(input, output); - } else { - group->recv(output.data(), output.nbytes(), 0); - group->send(input.data(), input.nbytes(), 0); - sum_inplace(input, output); - } + int size = group->size(); + int rank = group->rank(); + + if ((size & (size - 1)) != 0) { + throw std::runtime_error("Only powers of 2 are currently supported"); + } + + // If not inplace all reduce then copy the input to the output first. + if (input.data() != output.data()) { + std::memcpy(output.data(), input.data(), input.nbytes()); + } + + // Butterfly all reduce + for (int distance = 1; distance <= size / 2; distance *= 2) { + group->send_recv_sum(output, rank ^ distance); } } void all_gather(Group group_, const array& input_, array& output) { auto group = std::static_pointer_cast(group_.raw_group()); - if (group->size() != 2) { - throw std::runtime_error("Only pairwise communication supported for now"); - } array input = ensure_row_contiguous(input_); - if (group->rank() == 0) { - group->send(input.data(), input.nbytes(), 1); - group->recv(output.data() + input.nbytes(), input.nbytes(), 1); - std::memcpy(output.data(), input.data(), input.nbytes()); - } else { - group->recv(output.data(), input.nbytes(), 0); - group->send(input.data(), input.nbytes(), 0); - std::memcpy( - output.data() + input.nbytes(), - input.data(), - input.nbytes()); + std::future sent; + std::future received; + + int rank = group->rank(); + int size = group->size(); + + if ((size & (size - 1)) != 0) { + throw std::runtime_error("Only powers of 2 are currently supported"); } + + // Butterfly all gather + int peer = rank ^ 1; + if (peer < rank) { + received = group->recv_async( + output.data() + peer * input.nbytes(), input.nbytes(), peer); + sent = group->send_async(input.data(), input.nbytes(), peer); + } else { + sent = group->send_async(input.data(), input.nbytes(), peer); + received = group->recv_async( + output.data() + peer * input.nbytes(), input.nbytes(), peer); + } + std::memcpy( + output.data() + rank * input.nbytes(), + input.data(), + input.nbytes()); + + for (int distance = 2; distance <= size / 2; distance *= 2) { + sent.wait(); + received.wait(); + int peer = rank ^ distance; + int their_offset = peer & ~(distance - 1); + int our_offset = rank & ~(distance - 1); + + if (peer < rank) { + received = group->recv_async( + output.data() + their_offset * input.nbytes(), + distance * input.nbytes(), + peer); + sent = group->send_async( + output.data() + our_offset * input.nbytes(), + distance * input.nbytes(), + peer); + } else { + sent = group->send_async( + output.data() + our_offset * input.nbytes(), + distance * input.nbytes(), + peer); + received = group->recv_async( + output.data() + their_offset * input.nbytes(), + distance * input.nbytes(), + peer); + } + } + sent.wait(); + received.wait(); } void send(Group group_, const array& input_, int dst) {