From 34dd079a645aa0289ec824d0f0761d55b774a83b Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 27 Aug 2024 17:20:48 -0700
Subject: [PATCH] Start a sockets based distributed backend

---
Notes (ignored by git am):
  * The text inside angle brackets (header names, template arguments) was
    reconstructed from how each symbol is used. <json.hpp> in sockets.cpp and
    the two context includes of mlx/io/safetensors.cpp are assumptions; confirm
    them against the tree before applying.
  * Context whitespace in mlx/distributed/CMakeLists.txt is assumed; apply with
    --ignore-whitespace if the hunk does not match.
  * A single sendto() is still issued per send, so buffers larger than the
    maximum datagram size are not supported yet.

 mlx/distributed/CMakeLists.txt         |   6 +-
 mlx/distributed/sockets/CMakeLists.txt |   5 +
 mlx/distributed/sockets/sockets.cpp    | 366 +++++++++++++++++++++++++
 mlx/io/safetensors.cpp                 |   2 +-
 4 files changed, 377 insertions(+), 2 deletions(-)
 create mode 100644 mlx/distributed/sockets/CMakeLists.txt
 create mode 100644 mlx/distributed/sockets/sockets.cpp

diff --git a/mlx/distributed/CMakeLists.txt b/mlx/distributed/CMakeLists.txt
--- a/mlx/distributed/CMakeLists.txt
+++ b/mlx/distributed/CMakeLists.txt
@@ -1,8 +1,12 @@
 target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
                            ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp)
 
-if(MPI_FOUND AND MLX_BUILD_CPU)
+# Select exactly one distributed backend: sockets when explicitly requested,
+# MPI when found, and no_distributed.cpp otherwise (including non-CPU builds).
+if(MLX_BUILD_CPU AND MLX_CUSTOM_DISTRIBUTED)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/sockets)
+elseif(MLX_BUILD_CPU AND MPI_FOUND)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
 else()
   target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
 endif()
diff --git a/mlx/distributed/sockets/CMakeLists.txt b/mlx/distributed/sockets/CMakeLists.txt
new file mode 100644
index 000000000..e038d49f3
--- /dev/null
+++ b/mlx/distributed/sockets/CMakeLists.txt
@@ -0,0 +1,5 @@
+target_sources(
+  mlx
+  PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/sockets.cpp
+)
diff --git a/mlx/distributed/sockets/sockets.cpp b/mlx/distributed/sockets/sockets.cpp
new file mode 100644
--- /dev/null
+++ b/mlx/distributed/sockets/sockets.cpp
@@ -0,0 +1,366 @@
+// Copyright © 2024 Apple Inc.
+
+#include <netdb.h>
+#include <sys/socket.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <cerrno>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <memory>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+#include <json.hpp>
+
+#include "mlx/backend/common/copy.h"
+#include "mlx/distributed/distributed.h"
+#include "mlx/distributed/distributed_impl.h"
+
+using json = nlohmann::json;
+
+namespace mlx::core::distributed {
+
+namespace {
+
+// Accumulate element-wise: output[i] += input[i] for every i in [0, N).
+template <typename T>
+void sum_inplace(const T* input, T* output, size_t N) {
+  while (N-- > 0) {
+    *output += *input;
+    input++;
+    output++;
+  }
+}
+
+// Add `input` into `output` element-wise, dispatching on the dtype of `input`.
+// Both arrays are read through their raw data pointers for input.size()
+// elements; matching dtype, size and dense layout are not checked here.
+void sum_inplace(const array& input, array& output) {
+  switch (input.dtype()) {
+    case bool_:
+      return sum_inplace(input.data<bool>(), output.data<bool>(), input.size());
+    case int8:
+      return sum_inplace(
+          input.data<int8_t>(), output.data<int8_t>(), input.size());
+    case uint8:
+      return sum_inplace(
+          input.data<uint8_t>(), output.data<uint8_t>(), input.size());
+    case int16:
+      return sum_inplace(
+          input.data<int16_t>(), output.data<int16_t>(), input.size());
+    case uint16:
+      return sum_inplace(
+          input.data<uint16_t>(), output.data<uint16_t>(), input.size());
+    case int32:
+      return sum_inplace(
+          input.data<int32_t>(), output.data<int32_t>(), input.size());
+    case uint32:
+      return sum_inplace(
+          input.data<uint32_t>(), output.data<uint32_t>(), input.size());
+    case int64:
+      return sum_inplace(
+          input.data<int64_t>(), output.data<int64_t>(), input.size());
+    case uint64:
+      return sum_inplace(
+          input.data<uint64_t>(), output.data<uint64_t>(), input.size());
+    case float16:
+      return sum_inplace(
+          input.data<float16_t>(), output.data<float16_t>(), input.size());
+    case bfloat16:
+      return sum_inplace(
+          input.data<bfloat16_t>(), output.data<bfloat16_t>(), input.size());
+    case float32:
+      return sum_inplace(
+          input.data<float>(), output.data<float>(), input.size());
+    case complex64:
+      return sum_inplace(
+          input.data<complex64_t>(), output.data<complex64_t>(), input.size());
+  }
+}
+
+// Return `arr` itself when it is already row contiguous, otherwise a new
+// array of the same shape and dtype filled by a general copy.
+array ensure_row_contiguous(const array& arr) {
+  if (arr.flags().row_contiguous) {
+    return arr;
+  } else {
+    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+    copy(arr, arr_copy, CopyType::General);
+    return arr_copy;
+  }
+}
+
+// A resolved socket address together with its length.
+struct address_t {
+  sockaddr_storage addr{};
+  socklen_t len{0};
+
+  // The elaborated `struct sockaddr` keeps the type name usable even though
+  // this member function shares its name.
+  const struct sockaddr* sockaddr() const {
+    return reinterpret_cast<const struct sockaddr*>(&addr);
+  }
+};
+
+// Resolve `ip` and `port` into a datagram socket address with getaddrinfo.
+// Only the first result is used. Throws std::runtime_error when the lookup
+// fails.
+address_t parse_address(const std::string& ip, const std::string& port) {
+  struct addrinfo hints, *res;
+  std::memset(&hints, 0, sizeof(hints));
+  hints.ai_family = AF_UNSPEC;
+  hints.ai_socktype = SOCK_DGRAM;
+
+  int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
+  if (status != 0) {
+    std::ostringstream msg;
+    msg << "Can't parse peer address " << ip << ":" << port;
+    throw std::runtime_error(msg.str());
+  }
+
+  address_t result;
+  std::memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
+  result.len = res->ai_addrlen;
+  freeaddrinfo(res);
+
+  return result;
+}
+
+// Read the peer list from the JSON file named by MLX_HOSTFILE. The file is
+// parsed as a sequence of objects with string members "ip" and "port".
+// Returns an empty list when the variable is unset; throws when the file
+// cannot be opened or parsed.
+std::vector<address_t> load_peers() {
+  std::vector<address_t> peers;
+
+  const char* hostfile = std::getenv("MLX_HOSTFILE");
+  if (hostfile == nullptr) {
+    return peers;
+  }
+
+  std::ifstream f(hostfile);
+  if (!f.is_open()) {
+    std::ostringstream msg;
+    msg << "Can't open hostfile " << hostfile;
+    throw std::runtime_error(msg.str());
+  }
+
+  json hosts = json::parse(f);
+  for (auto& h : hosts) {
+    // parse_address returns a prvalue, no std::move needed to avoid a copy.
+    peers.push_back(parse_address(
+        h["ip"].get<std::string>(), h["port"].get<std::string>()));
+  }
+
+  return peers;
+}
+
+// A group of peers that talk over a single UDP socket bound to the address
+// of this process' rank.
+struct GroupImpl {
+  GroupImpl(std::vector<address_t> peers, int rank, bool global)
+      : peers_(std::move(peers)), rank_(rank), global_(global) {
+    if (rank_ < 0) {
+      throw std::runtime_error("Rank cannot be negative");
+    }
+    if (rank_ > 0 && static_cast<size_t>(rank_) >= peers_.size()) {
+      throw std::runtime_error(
+          "Rank cannot be larger than the size of the group");
+    }
+    if (global_ && static_cast<size_t>(rank_) < peers_.size()) {
+      // Match the family of the resolved address; the lookup uses AF_UNSPEC
+      // so the peer may be IPv4 or IPv6.
+      socket_fd_ = socket(peers_[rank_].addr.ss_family, SOCK_DGRAM, 0);
+      if (socket_fd_ < 0) {
+        std::ostringstream msg;
+        msg << "Couldn't create socket (error: " << errno << ")";
+        throw std::runtime_error(msg.str());
+      }
+      int success =
+          bind(socket_fd_, peers_[rank_].sockaddr(), peers_[rank_].len);
+      if (success < 0) {
+        // The destructor does not run when the constructor throws, so the
+        // descriptor has to be released here.
+        int bind_errno = errno;
+        close(socket_fd_);
+        socket_fd_ = -1;
+        std::ostringstream msg;
+        msg << "Couldn't bind socket (error: " << bind_errno << ")";
+        throw std::runtime_error(msg.str());
+      }
+    }
+  }
+
+  ~GroupImpl() {
+    // Only close a descriptor that was actually opened; a group without
+    // peers never creates a socket.
+    if (socket_fd_ >= 0) {
+      close(socket_fd_);
+    }
+  }
+
+  // The group owns the socket descriptor so it must not be copied.
+  GroupImpl(const GroupImpl&) = delete;
+  GroupImpl& operator=(const GroupImpl&) = delete;
+
+  int rank() {
+    return rank_;
+  }
+
+  // Number of peers, or 1 when no peers were configured.
+  int size() {
+    return std::max<size_t>(peers_.size(), 1);
+  }
+
+  // Send `len` bytes to peer `dst` as a single datagram.
+  void send(const char* buf, size_t len, int dst) {
+    if (dst < 0 || static_cast<size_t>(dst) >= peers_.size()) {
+      throw std::runtime_error("Invalid destination rank");
+    }
+    ssize_t r = sendto(
+        socket_fd_, buf, len, 0, peers_[dst].sockaddr(), peers_[dst].len);
+    if (r < 0) {
+      throw std::runtime_error("Send failed.");
+    }
+  }
+
+  // Receive exactly `len` bytes into `buf`, possibly across several
+  // datagrams. The sender's address is not compared against `src`.
+  void recv(char* buf, size_t len, int src) {
+    sockaddr_storage addr;
+    while (len != 0) {
+      // addr_len is a value-result argument: it must hold the capacity of
+      // `addr` on entry and is overwritten on return.
+      socklen_t addr_len = sizeof(addr);
+      ssize_t r = recvfrom(
+          socket_fd_,
+          buf,
+          len,
+          0,
+          reinterpret_cast<struct sockaddr*>(&addr),
+          &addr_len);
+      if (r <= 0) {
+        throw std::runtime_error("Recv failed");
+      }
+      buf += r;
+      len -= r;
+    }
+  }
+
+ private:
+  std::vector<address_t> peers_;
+  int rank_;
+  bool global_;
+  int socket_fd_{-1};
+};
+
+} // namespace
+
+bool is_available() {
+  return true;
+}
+
+int Group::rank() {
+  return std::static_pointer_cast<GroupImpl>(group_)->rank();
+}
+
+int Group::size() {
+  return std::static_pointer_cast<GroupImpl>(group_)->size();
+}
+
+Group Group::split(int color, int key) {
+  throw std::runtime_error("Splitting not supported yet");
+}
+
+// Create the global group on first use and return a handle to it. The rank
+// comes from MLX_RANK (0 when unset) and the peers from MLX_HOSTFILE. With
+// `strict` set, an empty peer list is an error; otherwise a group without
+// peers is returned.
+Group init(bool strict /* = false */) {
+  static std::shared_ptr<GroupImpl> global_group = nullptr;
+
+  if (global_group == nullptr) {
+    auto peers = load_peers();
+    int rank = 0;
+    if (const char* rank_buf = std::getenv("MLX_RANK")) {
+      rank = std::atoi(rank_buf);
+    }
+    if (peers.size() == 0) {
+      if (strict) {
+        throw std::runtime_error("Can't initialize distributed");
+      }
+    }
+    global_group = std::make_shared<GroupImpl>(std::move(peers), rank, true);
+  }
+  return Group(global_group);
+}
+
+namespace detail {
+
+Stream communication_stream() {
+  static Stream comm_stream = new_stream(Device::cpu);
+  return comm_stream;
+}
+
+// Pairwise all-reduce: each rank receives the other rank's buffer into
+// `output` and then adds its own input on top. Rank 0 sends first while
+// rank 1 receives first.
+void all_sum(Group group_, const array& input_, array& output) {
+  auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
+  if (group->size() != 2) {
+    throw std::runtime_error("Only pairwise communication supported for now");
+  }
+  array input = ensure_row_contiguous(input_);
+  if (input.data<void>() == output.data<void>()) {
+    throw std::runtime_error("Donation not supported");
+  }
+  if (group->rank() == 0) {
+    group->send(input.data<char>(), input.nbytes(), 1);
+    group->recv(output.data<char>(), output.nbytes(), 1);
+  } else {
+    group->recv(output.data<char>(), output.nbytes(), 0);
+    group->send(input.data<char>(), input.nbytes(), 0);
+  }
+  sum_inplace(input, output);
+}
+
+// Pairwise all-gather: `output` holds rank 0's bytes followed by rank 1's.
+void all_gather(Group group_, const array& input_, array& output) {
+  auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
+  if (group->size() != 2) {
+    throw std::runtime_error("Only pairwise communication supported for now");
+  }
+  array input = ensure_row_contiguous(input_);
+  if (group->rank() == 0) {
+    group->send(input.data<char>(), input.nbytes(), 1);
+    group->recv(output.data<char>() + input.nbytes(), input.nbytes(), 1);
+    std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
+  } else {
+    group->recv(output.data<char>(), input.nbytes(), 0);
+    group->send(input.data<char>(), input.nbytes(), 0);
+    std::memcpy(
+        output.data<char>() + input.nbytes(),
+        input.data<char>(),
+        input.nbytes());
+  }
+}
+
+void send(Group group_, const array& input_, int dst) {
+  array input = ensure_row_contiguous(input_);
+  auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
+  group->send(input.data<char>(), input.nbytes(), dst);
+}
+
+void recv(Group group_, array& out, int src) {
+  auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
+  group->recv(out.data<char>(), out.nbytes(), src);
+}
+
+} // namespace detail
+
+} // namespace mlx::core::distributed
diff --git a/mlx/io/safetensors.cpp b/mlx/io/safetensors.cpp
index f022fb25f..7268c37be 100644
--- a/mlx/io/safetensors.cpp
+++ b/mlx/io/safetensors.cpp
@@ -1,5 +1,5 @@
 // Copyright © 2023 Apple Inc.
-//
+
 #include <json.hpp>
 #include <stack>
 