From b1a60b2d2d9f5e13f794008511d25536d983854c Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 9 Sep 2025 13:32:06 -0700 Subject: [PATCH] Initial working all reduce --- CMakeLists.txt | 4 + mlx/distributed/CMakeLists.txt | 5 + mlx/distributed/distributed.cpp | 14 +- mlx/distributed/ibv/CMakeLists.txt | 2 + mlx/distributed/ibv/ibv.cpp | 955 +++++++++++++++++++++++++++++ mlx/distributed/ibv/ibv.h | 12 + mlx/distributed/ring/ring.cpp | 193 +----- mlx/distributed/utils.cpp | 189 ++++++ mlx/distributed/utils.h | 62 ++ 9 files changed, 1273 insertions(+), 163 deletions(-) create mode 100644 mlx/distributed/ibv/CMakeLists.txt create mode 100644 mlx/distributed/ibv/ibv.cpp create mode 100644 mlx/distributed/ibv/ibv.h create mode 100644 mlx/distributed/utils.cpp create mode 100644 mlx/distributed/utils.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 3487c22c5..9117342e5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -119,6 +119,10 @@ if(MLX_BUILD_METAL) COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version" OUTPUT_VARIABLE MACOS_SDK_VERSION OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY) + execute_process( + COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-path" + OUTPUT_VARIABLE CMAKE_OSX_SYSROOT + OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY) if(${MACOS_SDK_VERSION} LESS 14.0) message( diff --git a/mlx/distributed/CMakeLists.txt b/mlx/distributed/CMakeLists.txt index b7762f6a7..81935ac64 100644 --- a/mlx/distributed/CMakeLists.txt +++ b/mlx/distributed/CMakeLists.txt @@ -4,6 +4,11 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp) +if(MLX_BUILD_CPU AND NOT WIN32) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp) +endif() + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ibv) diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp index 2f5ea8029..74637edb5 100644 --- a/mlx/distributed/distributed.cpp +++ b/mlx/distributed/distributed.cpp @@ -5,6 +5,7 @@ #include "mlx/backend/cuda/cuda.h" #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" +#include "mlx/distributed/ibv/ibv.h" #include "mlx/distributed/mpi/mpi.h" #include "mlx/distributed/nccl/nccl.h" #include "mlx/distributed/ring/ring.h" @@ -102,7 +103,8 @@ class EmptyGroup : public GroupImpl { } // namespace detail bool is_available() { - return mpi::is_available() || ring::is_available() || nccl::is_available(); + return mpi::is_available() || ring::is_available() || nccl::is_available() || + ibv::is_available(); } int Group::rank() const { @@ -135,6 +137,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) { group = ring::init(strict); } else if (bk == "nccl") { group = nccl::init(strict); + } else if (bk == "ibv") { + group = ibv::init(strict); } else if (bk == "any") { if (mlx::core::cu::is_available()) { group = nccl::init(false); @@ -148,13 +152,17 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) { group = mpi::init(false); bk_ = "mpi"; } + if (group == nullptr) { + group = ibv::init(false); + bk_ = "ibv"; + } if (group == nullptr && strict) { throw std::runtime_error("[distributed] Couldn't initialize any backend"); } } else { std::ostringstream msg; - msg << "[distributed] The only valid values for backend are 'any', 'mpi' " - << "and 'ring' but '" << bk << "' was provided."; + msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', " + << "'ibv' and 'ring' but '" << bk << "' was provided."; throw std::invalid_argument(msg.str()); } diff --git a/mlx/distributed/ibv/CMakeLists.txt b/mlx/distributed/ibv/CMakeLists.txt new file mode 100644 index 000000000..2ed69c075 --- /dev/null +++ b/mlx/distributed/ibv/CMakeLists.txt @@ -0,0 +1,2 @@ +target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ibv.cpp) +target_link_libraries(mlx PRIVATE rdma) diff --git a/mlx/distributed/ibv/ibv.cpp b/mlx/distributed/ibv/ibv.cpp new file mode 100644 index 000000000..4b829a182 --- /dev/null +++ b/mlx/distributed/ibv/ibv.cpp @@ -0,0 +1,955 @@ +// Copyright © 2025 Apple Inc. + +#include +#include +#include +#include + +#include + +#include "mlx/backend/cpu/encoder.h" +#include "mlx/distributed/distributed_impl.h" +#include "mlx/distributed/utils.h" +#include "mlx/dtype_utils.h" + +constexpr const char* IBV_TAG = "[ibv]"; +constexpr int NUM_BUFFERS = 2; +constexpr int BUFFER_SIZE = 4096; +constexpr int MAX_SEND_WR = 32; +constexpr int MAX_RECV_WR = 32; +constexpr int SEND_WR = 1; +constexpr int RECV_WR = 2; +constexpr int MAX_PEERS = 8; + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; +using json = nlohmann::json; +namespace detail = mlx::core::distributed::detail; +namespace allocator = mlx::core::allocator; + +template +struct is_container : std::false_type {}; + +template +struct is_container< + T, + std::void_t> + : std::true_type {}; + +std::ostream& operator<<(std::ostream& os, const ibv_gid& gid) { + os << std::hex << std::setfill('0'); + for (int i = 0; i < 16; i += 2) { + uint16_t part = (gid.raw[i] << 8) | gid.raw[i + 1]; + os << std::setw(4) << part; + if (i < 14) + os << ":"; + } + os << std::dec; + return os; +} + +void* page_aligned_alloc(size_t num_bytes) { + static size_t page_size = sysconf(_SC_PAGESIZE); + void * buf; + if (posix_memalign(&buf, page_size, num_bytes)) { + return nullptr; + } + return buf; +} + +/** + * Contains the information that defines a destination to a remote device. + * Basically we can compute our own destination and share it with remote hosts + * over the side channel. + */ +struct Destination { + int local_id; + int queue_pair_number; + int packet_sequence_number; + ibv_gid global_identifier; +}; + +/** + * A buffer that can be registered to a number of protection domains. + */ +class SharedBuffer { + public: + SharedBuffer(size_t num_bytes) + : data_(page_aligned_alloc(num_bytes)), + num_bytes_(num_bytes) {} + ~SharedBuffer() { + for (auto& [pd, mr] : memory_regions_) { + ibv_dereg_mr(mr); + } + if (data_ != nullptr) { + std::free(data_); + } + } + + SharedBuffer(const SharedBuffer&) = delete; + SharedBuffer& operator=(const SharedBuffer&) = delete; + SharedBuffer(SharedBuffer&& b) + : data_(nullptr), num_bytes_(0) { + std::swap(data_, b.data_); + std::swap(num_bytes_, b.num_bytes_); + std::swap(memory_regions_, b.memory_regions_); + } + + void register_to_protection_domain(ibv_pd* protection_domain) { + auto [it, inserted] = memory_regions_.insert({protection_domain, nullptr}); + if (!inserted) { + throw std::runtime_error( + "[ibv] Buffer can be registered once per protection domain"); + } + + it->second = ibv_reg_mr( + protection_domain, + data_, + num_bytes_, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | + IBV_ACCESS_REMOTE_WRITE); + if (!it->second) { + throw std::runtime_error("[ibv] Register memory region failed"); + } + } + + size_t size() const { + return num_bytes_; + } + + uint32_t local_key(ibv_pd* protection_domain) const { + return memory_regions_.at(protection_domain)->lkey; + } + + ibv_sge to_scatter_gather_entry(ibv_pd* protection_domain) const { + ibv_sge entry; + entry.addr = reinterpret_cast(data_); + entry.length = size(); + entry.lkey = local_key(protection_domain); + return entry; + } + + template + T* data() { + return static_cast(data_); + } + + template + T* begin() { + return static_cast(data_); + } + + template + T* end() { + return static_cast(data_) + size() / sizeof(T); + } + + private: + void* data_; + size_t num_bytes_; + std::unordered_map memory_regions_; +}; + +/** + * Manipulates an RDMA connection. Enables (among other things) + * + * - Creating a queue pair + * - Sending and receiving + * - Checking completion + */ +struct Connection { + ibv_context* ctx; + ibv_pd* protection_domain; + ibv_cq* completion_queue; + ibv_qp* queue_pair; + Destination src; // holds the local information + + Connection(ibv_context* ctx_) + : ctx(ctx_), + protection_domain(nullptr), + completion_queue(nullptr), + queue_pair(nullptr) { + src.local_id = -1; + } + + Connection(Connection&& c) : Connection(nullptr) { + std::swap(ctx, c.ctx); + std::swap(protection_domain, c.protection_domain); + std::swap(completion_queue, c.completion_queue); + std::swap(queue_pair, c.queue_pair); + std::swap(src, c.src); + } + + Connection(const Connection&) = delete; + Connection& operator=(Connection&) = delete; + + ~Connection() { + if (queue_pair != nullptr) { + ibv_destroy_qp(queue_pair); + } + if (completion_queue != nullptr) { + ibv_destroy_cq(completion_queue); + } + if (protection_domain != nullptr) { + ibv_dealloc_pd(protection_domain); + } + if (ctx != nullptr) { + ibv_close_device(ctx); + } + } + + void allocate_protection_domain() { + protection_domain = ibv_alloc_pd(ctx); + if (protection_domain == nullptr) { + throw std::runtime_error("[ibv] Couldn't allocate protection domain"); + } + } + + void create_completion_queue(int num_entries) { + completion_queue = ibv_create_cq(ctx, num_entries, nullptr, nullptr, 0); + if (completion_queue == nullptr) { + throw std::runtime_error("[ibv] Couldn't create completion queue"); + } + } + + void create_queue_pair() { + ibv_qp_init_attr init_attr; + init_attr.qp_context = ctx; + init_attr.qp_context = ctx; + init_attr.send_cq = completion_queue; + init_attr.recv_cq = completion_queue; + init_attr.srq = nullptr; + init_attr.cap.max_send_wr = MAX_SEND_WR; + init_attr.cap.max_recv_wr = MAX_RECV_WR; + init_attr.cap.max_send_sge = 1; + init_attr.cap.max_recv_sge = 1; + init_attr.cap.max_inline_data = 0; + init_attr.qp_type = IBV_QPT_UC; + init_attr.sq_sig_all = 0; + + queue_pair = ibv_create_qp(protection_domain, &init_attr); + + if (queue_pair == nullptr) { + throw std::runtime_error("[ibv] Couldn't create queue pair"); + } + } + + const Destination& info() { + if (queue_pair == nullptr || src.local_id >= 0) { + return src; + } + + ibv_port_attr port_attr; + ibv_query_port(ctx, 1, &port_attr); + ibv_gid gid; + ibv_query_gid(ctx, 1, 1, &gid); + + src.local_id = port_attr.lid; + src.queue_pair_number = queue_pair->qp_num; + src.packet_sequence_number = 7; // TODO: Change to sth random + src.global_identifier = gid; + + return src; + } + + void queue_pair_init() { + ibv_qp_attr attr = {}; + attr.qp_state = IBV_QPS_INIT; + attr.port_num = 1; + attr.pkey_index = 0; + attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | + IBV_ACCESS_REMOTE_WRITE; + + int mask = + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS; + + if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) { + std::ostringstream msg; + msg << "[ibv] Changing queue pair to INIT failed with errno " << status; + throw std::invalid_argument(msg.str()); + } + } + + void queue_pair_rtr(const Destination& dst) { + ibv_qp_attr attr = {}; + memset(&attr, 0, sizeof(attr)); + attr.qp_state = IBV_QPS_RTR; + attr.path_mtu = IBV_MTU_1024; + attr.rq_psn = dst.packet_sequence_number; + attr.dest_qp_num = dst.queue_pair_number; + attr.ah_attr.dlid = dst.local_id; + attr.ah_attr.sl = 0; + attr.ah_attr.src_path_bits = 0; + attr.ah_attr.port_num = 1; + attr.ah_attr.is_global = 0; + + if (dst.global_identifier.global.interface_id) { + attr.ah_attr.is_global = 1; + attr.ah_attr.grh.hop_limit = 1; + attr.ah_attr.grh.dgid = dst.global_identifier; + attr.ah_attr.grh.sgid_index = 1; + } + + int mask = IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | + IBV_QP_RQ_PSN; + + if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) { + std::ostringstream msg; + msg << "[ibv] Changing queue pair to RTR failed with errno " << status; + throw std::invalid_argument(msg.str()); + } + } + + void queue_pair_rts() { + ibv_qp_attr attr = {}; + attr.qp_state = IBV_QPS_RTS; + attr.sq_psn = src.packet_sequence_number; + + int mask = IBV_QP_STATE | IBV_QP_SQ_PSN; + + if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) { + std::ostringstream msg; + msg << "[ibv] Changing queue pair to RTS failed with errno " << status; + throw std::invalid_argument(msg.str()); + } + } + + void post_send(const SharedBuffer& buff, uint64_t work_request_id) { + ibv_send_wr work_request, *bad_work_request; + + auto entry = buff.to_scatter_gather_entry(protection_domain); + work_request.wr_id = work_request_id; + work_request.sg_list = &entry; + work_request.num_sge = 1; + work_request.opcode = IBV_WR_SEND; + work_request.send_flags = IBV_SEND_SIGNALED; + work_request.next = nullptr; + + if (int status = + ibv_post_send(queue_pair, &work_request, &bad_work_request); + status != 0) { + std::ostringstream msg; + msg << "[ibv] Send failed with error code " << status; + throw std::invalid_argument(msg.str()); + } + } + + void post_recv(const SharedBuffer& buff, uint64_t work_request_id) { + ibv_recv_wr work_request, *bad_work_request; + + auto entry = buff.to_scatter_gather_entry(protection_domain); + work_request.wr_id = work_request_id; + work_request.sg_list = &entry; + work_request.num_sge = 1; + work_request.next = nullptr; + + if (int status = + ibv_post_recv(queue_pair, &work_request, &bad_work_request); + status != 0) { + std::ostringstream msg; + msg << "[ibv] Recv failed with error code " << status; + throw std::invalid_argument(msg.str()); + } + } +}; + +/** + * Implement a TCP side channel to exchange information about the RDMA + * connections. + * + * Implements a simple all gather where every node sends to rank 0 and rank 0 + * broadcasts to every node. + */ +class SideChannel { + public: + SideChannel(int rank, int size, const char* addr) : rank_(rank), size_(size) { + auto address = detail::parse_address(addr); + + if (rank_ == 0) { + detail::TCPSocket server(IBV_TAG); + server.listen(IBV_TAG, address); + + for (int i = 0; i < size - 1; i++) { + sockets_.push_back(server.accept(IBV_TAG)); + } + } else { + sockets_.push_back(detail::TCPSocket::connect( + IBV_TAG, address, 4, 1000, [](int attempt, int wait) { + std::cerr << IBV_TAG << " Connection attempt " << attempt + << " waiting " << wait << " ms" << std::endl; + })); + } + } + + SideChannel(const SideChannel&) = delete; + SideChannel& operator=(const SideChannel&) = delete; + + SideChannel(SideChannel&& sc) + : rank_(sc.rank_), size_(sc.size_), sockets_(std::move(sc.sockets_)) { + sc.rank_ = -1; + sc.size_ = -1; + } + + template + std::vector all_gather(const T& v) { + std::vector result(size_); + + // T is a container of stuff like std::vector or std::string + if constexpr (is_container::value) { + using U = typename T::value_type; + + // Share the lengths first and set the communication size to be the + // maximum length of the containers. + auto lengths = all_gather(v.size()); + auto max_len = *std::max_element(lengths.begin(), lengths.end()); + for (auto& s : result) { + s.resize(max_len); + } + + // All gather of length max_len + if (rank_ == 0) { + std::copy(v.begin(), v.end(), result[rank_].begin()); + for (int i = 1; i < size_; i++) { + sockets_[i - 1].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len); + } + for (int i = 1; i < size_; i++) { + for (int j = 0; j < size_; j++) { + sockets_[i - 1].send( + IBV_TAG, result[j].data(), sizeof(U) * max_len); + } + } + } else { + std::copy(v.begin(), v.end(), result[rank_].begin()); + sockets_[0].send(IBV_TAG, result[rank_].data(), sizeof(U) * max_len); + for (int i = 0; i < size_; i++) { + sockets_[0].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len); + } + } + + // Resize the outputs back to the original length + for (int i = 0; i < size_; i++) { + result[i].resize(lengths[i]); + } + } + + // T is a scalar + else { + if (rank_ == 0) { + result[rank_] = v; + for (int i = 1; i < size_; i++) { + sockets_[i - 1].recv(IBV_TAG, &result[i], sizeof(T)); + } + for (int i = 1; i < size_; i++) { + sockets_[i - 1].send(IBV_TAG, result.data(), size_ * sizeof(T)); + } + } else { + sockets_[0].send(IBV_TAG, &v, sizeof(T)); + sockets_[0].recv(IBV_TAG, result.data(), size_ * sizeof(T)); + } + } + + return result; + } + + private: + int rank_; + int size_; + std::vector sockets_; +}; + +/** + * Manages a set of connections. Among other things it uses a side channel to + * exchange the necessary information and then configure the connections to be + * ready for RDMA operations. + */ +class ConnectionManager { + public: + ConnectionManager( + int rank, + const std::vector& device_names, + const char* coordinator_addr) + : rank_(rank), + size_(device_names.size()), + side_channel_(rank_, size_, coordinator_addr) { + create_contexts(device_names); + if (connections_[rank_].ctx != nullptr) { + throw std::runtime_error("[ibv] Malformed device file"); + } + } + + int rank() const { + return rank_; + } + + int size() const { + return size_; + } + + /** + * Performs the connection initialization. Namely, after this call all + * Connection objects should have a queue pair in RTS state. + */ + void initialize(int num_buffers, size_t num_bytes) { + // Create the queue pairs + for (auto& conn : connections_) { + if (conn.ctx == nullptr) { + continue; + } + conn.allocate_protection_domain(); + conn.create_completion_queue(MAX_SEND_WR + MAX_RECV_WR); + conn.create_queue_pair(); + } + + allocate_buffers(num_buffers, num_bytes); + + // Gather the information to be exchanged + std::vector info; + for (auto& conn : connections_) { + info.emplace_back(conn.info()); + } + auto all_infos = side_channel_.all_gather(info); + + // Transition queue pairs to RTS + for (int peer = 0; peer < size_; peer++) { + if (peer == rank_) { + continue; + } + auto peer_info = all_infos[peer][rank_]; + connections_[peer].queue_pair_init(); + connections_[peer].queue_pair_rtr(peer_info); + connections_[peer].queue_pair_rts(); + } + } + + void allocate_buffers(int num_buffers, size_t num_bytes) { + // Deregister any buffers and free the memory + buffers_.clear(); + + // Allocate the memory + for (int i = 0; i < num_buffers; i++) { + for (int j = 0; j < size_; j++) { + buffers_.emplace_back(num_bytes); + } + } + + for (int i = 0; i < num_buffers; i++) { + for (int j = 0; j < size_; j++) { + // This is our send buffer so register it with all pds so we can send + // it to all connected devices. + if (j == rank_) { + for (auto& conn : connections_) { + if (conn.ctx != nullptr) { + buffers_[i * size_ + j].register_to_protection_domain( + conn.protection_domain); + } + } + } + + // This is the recv buffer from rank j so register it to rank j's + // protection domain. + else { + buffers_[i * size_ + j].register_to_protection_domain( + connections_[j].protection_domain); + } + } + } + } + + void send_to(int rank, int buff) { + connections_[rank].post_send( + buffers_[buff * size_ + rank_], + SEND_WR << 16 | buff << 8 | rank); + } + + void recv_from(int rank, int buff) { + connections_[rank].post_recv( + buffers_[buff * size_ + rank], RECV_WR << 16 | buff << 8 | rank); + } + + /** + * Poll all connections and save the work completions and return the + * corresponding length. + */ + int poll(int num_completions, ibv_wc* work_completions) { + int completions = 0; + for (int r = 0; r < size_; r++) { + if (r == rank_) { + continue; + } + if (completions >= num_completions) { + return completions; + } + + int c = ibv_poll_cq( + connections_[r].completion_queue, + num_completions - completions, + work_completions + completions); + + completions += c; + } + return completions; + } + + /** + * + */ + int poll(int rank, int num_completions, ibv_wc* work_completions) { + return ibv_poll_cq( + connections_[rank].completion_queue, + num_completions, + work_completions); + } + + SharedBuffer& send_buffer(int buff) { + return buffers_[buff * size_ + rank_]; + } + + SharedBuffer& buffer(int rank, int buff) { + return buffers_[buff * size_ + rank]; + } + + void barrier() { + side_channel_.all_gather(0); + } + + private: + void create_contexts(const std::vector& device_names) { + int num_devices = 0; + ibv_device** devices = ibv_get_device_list(&num_devices); + for (auto& name : device_names) { + // Empty so add a nullptr context + if (name.empty()) { + connections_.emplace_back(nullptr); + continue; + } + + // Search for the name and try to open the device + for (int i = 0; i < num_devices; i++) { + if (name == ibv_get_device_name(devices[i])) { + auto ctx = ibv_open_device(devices[i]); + if (ctx == nullptr) { + std::ostringstream msg; + msg << "[ibv] Could not open device " << name; + throw std::runtime_error(msg.str()); + } + connections_.emplace_back(ctx); + break; + } + } + } + ibv_free_device_list(devices); + } + + int rank_; + int size_; + SideChannel side_channel_; + std::vector connections_; + std::vector buffers_; +}; + +template +struct SumOp { + void operator()(const T* input, T* output, size_t N) const { + while (N-- > 0) { + *output += *input; + input++; + output++; + } + } +}; + +std::vector load_device_names(int rank, const char* dev_file) { + std::vector device_names; + std::ifstream f(dev_file); + + json devices = json::parse(f); + devices = devices[rank]; + for (auto it = devices.begin(); it != devices.end(); it++) { + std::string n; + if (!it->is_null()) { + n = *it; + } + device_names.emplace_back(std::move(n)); + } + + return device_names; +} + +namespace mlx::core::distributed::ibv { + +class IBVGroup : public GroupImpl { + public: + IBVGroup(ConnectionManager cm) : cm_(std::move(cm)), rank_(cm.rank()), size_(cm.size()) {} + + Stream communication_stream(StreamOrDevice s) override { + return to_stream(s, Device::cpu); + } + + int rank() override { + return cm_.rank(); + } + + int size() override { + return cm_.size(); + } + + void all_sum(const array& input, array& output, Stream stream) override { + dispatch_all_types(output.dtype(), [&](auto type_tag) { + using T = MLX_GET_TYPE(type_tag); + all_reduce(input, output, stream, SumOp{}); + }); + } + + void all_max(const array& input, array& output, Stream stream) override { + } + + void all_min(const array& input, array& output, Stream stream) override { + } + + void all_gather(const array& input, array& output, Stream stream) override { + } + + void send(const array& input, int dst, Stream stream) override { + } + + void recv(array& out, int src, Stream stream) override { + } + + std::shared_ptr split(int color, int key = -1) override { + throw std::runtime_error("[ibv] Group split not supported."); + } + + private: + void post_recv_all(int buffer) { + for (int i = 0; i < size_; i++) { + if (i == rank_) { + continue; + } + cm_.recv_from(i, buffer); + } + } + + void post_send_all(int buffer) { + for (int i = 0; i < size_; i++) { + if (i == rank_) { + continue; + } + cm_.send_to(i, buffer); + } + } + + template + void all_reduce( + const array& input, + array& output, + Stream stream, + ReduceOp reduce_op) { + auto in_ptr = input.data(); + auto out_ptr = output.data(); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_output_array(output); + encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() { + // If not inplace all reduce then copy the input to the output first + if (in_ptr != out_ptr) { + std::memcpy(out_ptr, in_ptr, size * sizeof(T)); + } + + // Fully connected all reduce + T* data = out_ptr; + constexpr int64_t N = BUFFER_SIZE / sizeof(T); + int64_t total = static_cast(size); + int64_t offset = N; + int a = 0, b = 1; + + int mask_init = 1 << rank_; + int mask_target = (1 << size_) - 1; + + // Handle the first piece of data. + post_recv_all(a); + std::copy( + data, + data + std::min(N, total), + cm_.send_buffer(a).begin()); + post_send_all(a); + + int mask_a_send = mask_init; + int mask_a_recv = mask_init; + int mask_b_send = mask_init; + int mask_b_recv = mask_init; + + while (offset < total) { + // While the previous chunk is in flight copy to the next send buffer + std::copy( + data + offset, + data + std::min(offset + N, total), + cm_.send_buffer(b).begin()); + + // Send if the previous send is already done + for (int i = 0, m = 1; i < size_; i++, m *= 2) { + if (i == rank_) { + continue; + } + if (mask_a_send & m) { + cm_.send_to(i, b); + } + } + + // Recv the next buffer if the previous one is already done + for (int i = 0, m = 1; i < size_; i++, m *= 2) { + if (i == rank_) { + continue; + } + if (mask_a_recv & m) { + cm_.recv_from(i, b); + reduce_op( + cm_.buffer(i, a).begin(), + data + std::max(offset - N, 0LL), + std::min(N, total - offset)); + } + } + + // Loop until this chunk is all done + while (mask_a_send != mask_target || mask_a_recv != mask_target) { + ibv_wc wc[8]; + int n = cm_.poll(8, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; + + if (work_type == SEND_WR) { + if (buff == a) { + cm_.send_to(rank, b); + mask_a_send |= 1 << rank; + } else { + mask_b_send |= 1 << rank; + } + } else { + if (buff == a) { + cm_.recv_from(rank, b); + mask_a_recv |= 1 << rank; + reduce_op( + cm_.buffer(rank, a).begin(), + data + std::max(offset - N, 0LL), + std::min(N, total - offset)); + } else { + mask_b_recv |= 1 << rank; + } + } + } + } + + std::swap(a, b); + mask_a_send = mask_b_send; + mask_a_recv = mask_b_recv; + mask_b_send = mask_init; + mask_b_recv = mask_init; + offset += N; + } + + { + while (mask_a_send != mask_target || mask_a_recv != mask_target) { + ibv_wc wc[8]; + int n = cm_.poll(8, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; + + if (work_type == SEND_WR) { + mask_a_send |= 1 << rank; + } else { + mask_a_recv |= 1 << rank; + reduce_op( + cm_.buffer(rank, a).begin(), + data + offset - N, + std::min(N, total + N - offset)); + } + } + } + } + }); + } + + ConnectionManager cm_; + int rank_; + int size_; +}; + +bool is_available() { + return true; +} + +std::shared_ptr init(bool strict /* = false */) { + const char* dev_file = std::getenv("MLX_IBV_DEVICES"); + const char* coordinator = std::getenv("MLX_IBV_COORDINATOR"); + const char* rank_str = std::getenv("MLX_RANK"); + const char* ring_verbose = std::getenv("MLX_IBV_VERBOSE"); + + if (!dev_file || !coordinator || !rank_str) { + if (strict) { + std::ostringstream msg; + msg << "[ibv] You need to provide via environment variables a rank (MLX_RANK), " + << "a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_IBV_COORDINATOR) " + << "but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "") + << "\", MLX_IBV_DEVICES=\"" << ((dev_file) ? dev_file : "") + << "\" and MLX_IBV_COORDINATOR=\"" + << ((coordinator) ? coordinator : ""); + throw std::runtime_error(msg.str()); + } + return nullptr; + } + + auto rank = std::atoi(rank_str); + auto device_names = load_device_names(rank, dev_file); + + auto cm = ConnectionManager(rank, device_names, coordinator); + if (cm.size() > MAX_PEERS) { + std::ostringstream msg; + msg << "[ibv] The maximum number of supported peers is " + << MAX_PEERS << " but " << cm.size() << " was provided"; + throw std::runtime_error(msg.str()); + } + + cm.initialize(NUM_BUFFERS, BUFFER_SIZE); + cm.barrier(); + + return std::make_shared(std::move(cm)); + + //cm.recv_from(rank ^ 1, 0); + //cm.barrier(); + //std::fill( + // cm.buffer(rank, 0).begin(), + // cm.buffer(rank, 0).end(), + // 2 * rank + 1); + //cm.barrier(); + //cm.send_to(rank ^ 1, 0); + + //ibv_wc wc[8]; + //bool sent = false, recvd = false; + //while (!(sent && recvd)) { + // int num_completions = cm.poll(8, wc); + // for (int i = 0; i < num_completions; i++) { + // if (wc[i].status != IBV_WC_SUCCESS) { + // std::cout << "Error " << wc[i].status << std::endl; + // } + + // int work_type = wc[i].wr_id >> 16; + // int buff = (wc[i].wr_id >> 8) & 0xff; + // int rank = wc[i].wr_id & 0xff; + + // std::cout << work_type << " " << buff << " " << rank << std::endl; + // print_wc(wc[i]); + + // sent |= (work_type == SEND_WR); + // recvd |= (work_type == RECV_WR); + // } + //} + + //std::cout << rank << " ours " << *cm.buffer(rank, 0).data() << std::endl; + //std::cout << rank << " theirs " << *cm.buffer(rank ^ 1, 0).data() + // << std::endl; + + //return nullptr; +} + +} // namespace mlx::core::distributed::ibv diff --git a/mlx/distributed/ibv/ibv.h b/mlx/distributed/ibv/ibv.h new file mode 100644 index 000000000..7aca41e58 --- /dev/null +++ b/mlx/distributed/ibv/ibv.h @@ -0,0 +1,12 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/distributed/distributed.h" + +namespace mlx::core::distributed::ibv { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available(); +std::shared_ptr init(bool strict = false); + +} // namespace mlx::core::distributed::ibv diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index 23537c4d7..b3f293458 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -1,9 +1,6 @@ // Copyright © 2024 Apple Inc. -#include #include -#include -#include #include #include #include @@ -22,6 +19,7 @@ #include "mlx/backend/cpu/encoder.h" #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" +#include "mlx/distributed/utils.h" #include "mlx/threadpool.h" #ifndef SOL_TCP @@ -94,6 +92,7 @@ constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024; constexpr const size_t ALL_SUM_BUFFERS = 2; constexpr const int CONN_ATTEMPTS = 5; constexpr const int CONN_WAIT = 1000; +constexpr const char* RING_TAG = "[ring]"; using GroupImpl = mlx::core::distributed::detail::GroupImpl; using json = nlohmann::json; @@ -296,55 +295,6 @@ class CommunicationThreads { std::unordered_map threads_; }; -struct address_t { - sockaddr_storage addr; - socklen_t len; - - const sockaddr* get() const { - return (struct sockaddr*)&addr; - } -}; - -/** - * Parse a sockaddr from an ip and port provided as strings. - */ -address_t parse_address(const std::string& ip, const std::string& port) { - struct addrinfo hints, *res; - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - - int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res); - if (status != 0) { - std::ostringstream msg; - msg << "Can't parse address " << ip << ":" << port; - throw std::runtime_error(msg.str()); - } - - address_t result; - memcpy(&result.addr, res->ai_addr, res->ai_addrlen); - result.len = res->ai_addrlen; - freeaddrinfo(res); - - return result; -} - -/** - * Parse a sockaddr provided as an : string. - */ -address_t parse_address(const std::string& ip_port) { - auto colon = ip_port.find(":"); - if (colon == std::string::npos) { - std::ostringstream msg; - msg << "Can't parse address " << ip_port; - throw std::runtime_error(msg.str()); - } - std::string ip(ip_port.begin(), ip_port.begin() + colon); - std::string port(ip_port.begin() + colon + 1, ip_port.end()); - - return parse_address(ip, port); -} - /** * Load all addresses from the json hostfile. The hostfile is a list of * addresses in order of rank. For each rank there can be many addresses so @@ -357,15 +307,15 @@ address_t parse_address(const std::string& ip_port) { * ["ip3:5000", "ip3:5001"], * ] */ -std::vector> load_nodes(const char* hostfile) { - std::vector> nodes; +std::vector> load_nodes(const char* hostfile) { + std::vector> nodes; std::ifstream f(hostfile); json hosts = json::parse(f); for (auto& h : hosts) { - std::vector host; + std::vector host; for (auto& ips : h) { - host.push_back(parse_address(ips.get())); + host.push_back(std::move(detail::parse_address(ips.get()))); } nodes.push_back(std::move(host)); } @@ -377,73 +327,15 @@ std::vector> load_nodes(const char* hostfile) { * Create a socket and accept one connection for each of the provided * addresses. */ -std::vector accept_connections(const std::vector& addresses) { +std::vector accept_connections( + const std::vector& addresses) { std::vector sockets; int success; for (auto& address : addresses) { - // Create the socket to wait for connections from the peers - int sock = socket(AF_INET, SOCK_STREAM, 0); - if (sock < 0) { - std::ostringstream msg; - msg << "[ring] Couldn't create socket (error: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - - // Make sure we can launch immediately after shutdown by setting the - // reuseaddr option so that we don't get address already in use errors - int enable = 1; - success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); - if (success < 0) { - shutdown(sock, 2); - close(sock); - std::ostringstream msg; - msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int)); - if (success < 0) { - shutdown(sock, 2); - close(sock); - std::ostringstream msg; - msg << "[ring] Couldn't enable reuseport (error: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - - // Bind the socket to the address and port - success = bind(sock, address.get(), address.len); - if (success < 0) { - shutdown(sock, 2); - close(sock); - std::ostringstream msg; - msg << "[ring] Couldn't bind socket (error: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - - // Wait for connections - success = listen(sock, 0); - if (success < 0) { - shutdown(sock, 2); - close(sock); - std::ostringstream msg; - msg << "[ring] Couldn't listen (error: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - - int peer_socket = accept(sock, nullptr, nullptr); - if (peer_socket < 0) { - shutdown(sock, 2); - close(sock); - std::ostringstream msg; - msg << "[ring] Accept failed (error: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - - // Close the listening socket - shutdown(sock, 2); - close(sock); - - sockets.push_back(peer_socket); + detail::TCPSocket socket(RING_TAG); + socket.listen(RING_TAG, address); + sockets.push_back(socket.accept(RING_TAG)); } return sockets; @@ -454,55 +346,33 @@ std::vector accept_connections(const std::vector& addresses) { * provided addresses. */ std::vector make_connections( - const std::vector& addresses, + const std::vector& addresses, bool verbose) { std::vector sockets; int success; for (auto& address : addresses) { - int sock; - - // Attempt to connect to the peer CONN_ATTEMPTS times with exponential - // backoff. TODO: Do we need that? - for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) { - // Create the socket - sock = socket(AF_INET, SOCK_STREAM, 0); - if (sock < 0) { - std::ostringstream msg; - msg << "[ring] Couldn't create socket (error: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - - if (attempt > 0) { - int wait = (1 << (attempt - 1)) * CONN_WAIT; - log_info( - verbose, - "Attempt", - attempt, - "wait", - wait, - "ms (error:", - errno, - ")"); - std::this_thread::sleep_for(std::chrono::milliseconds(wait)); - } - - success = connect(sock, address.get(), address.len); - if (success == 0) { - break; - } - } - if (success < 0) { - std::ostringstream msg; - msg << "[ring] Couldn't connect (error: " << errno << ")"; - throw std::runtime_error(msg.str()); - } - - sockets.push_back(sock); + sockets.push_back(detail::TCPSocket::connect( + RING_TAG, + address, + CONN_ATTEMPTS, + CONN_WAIT, + [verbose](int attempt, int wait) { + log_info( + verbose, + "Attempt", + attempt, + "waiting", + wait, + "ms (error:", + errno, + ")"); + })); } return sockets; } + template struct SumOp { void operator()(const T* input, T* output, size_t N) { @@ -540,7 +410,10 @@ struct MinOp { class RingGroup : public GroupImpl { public: - RingGroup(int rank, std::vector> nodes, bool verbose) + RingGroup( + int rank, + std::vector> nodes, + bool verbose) : rank_(rank), verbose_(verbose), pool_(0) { if (rank_ > 0 && rank_ >= nodes.size()) { throw std::runtime_error( diff --git a/mlx/distributed/utils.cpp b/mlx/distributed/utils.cpp new file mode 100644 index 000000000..1598694c2 --- /dev/null +++ b/mlx/distributed/utils.cpp @@ -0,0 +1,189 @@ +// Copyright © 2025 Apple Inc. + +#include +#include +#include +#include + +#include "mlx/distributed/utils.h" + +namespace mlx::core::distributed::detail { + +/** + * Parse a sockaddr from an ip and port provided as strings. + */ +address_t parse_address(const std::string& ip, const std::string& port) { + struct addrinfo hints, *res; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + + int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res); + if (status != 0) { + std::ostringstream msg; + msg << "Can't parse address " << ip << ":" << port; + throw std::runtime_error(msg.str()); + } + + address_t result; + memcpy(&result.addr, res->ai_addr, res->ai_addrlen); + result.len = res->ai_addrlen; + freeaddrinfo(res); + + return result; +} + +/** + * Parse a sockaddr provided as an : string. + */ +address_t parse_address(const std::string& ip_port) { + auto colon = ip_port.find(":"); + if (colon == std::string::npos) { + std::ostringstream msg; + msg << "Can't parse address " << ip_port; + throw std::runtime_error(msg.str()); + } + std::string ip(ip_port.begin(), ip_port.begin() + colon); + std::string port(ip_port.begin() + colon + 1, ip_port.end()); + + return parse_address(ip, port); +} + +TCPSocket::TCPSocket(const char* tag) { + sock_ = socket(AF_INET, SOCK_STREAM, 0); + if (sock_ < 0) { + std::ostringstream msg; + msg << tag << " Couldn't create socket (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } +} + +TCPSocket::TCPSocket(TCPSocket&& s) { + sock_ = s.sock_; + s.sock_ = -1; +} + +TCPSocket::TCPSocket(int s) : sock_(s) {} + +TCPSocket::~TCPSocket() { + if (sock_ > 0) { + shutdown(sock_, 2); + close(sock_); + } +} + +void TCPSocket::listen(const char* tag, const address_t& addr) { + int success; + + // Make sure we can launch immediately after shutdown by setting the + // reuseaddr option so that we don't get address already in use errors + int enable = 1; + success = setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); + if (success < 0) { + std::ostringstream msg; + msg << tag << " Couldn't enable reuseaddr (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + success = setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int)); + if (success < 0) { + std::ostringstream msg; + msg << tag << " Couldn't enable reuseport (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + // Bind the socket to the address and port + success = bind(sock_, addr.get(), addr.len); + if (success < 0) { + std::ostringstream msg; + msg << tag << " Couldn't bind socket (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + // Prepare waiting for connections + success = ::listen(sock_, 0); + if (success < 0) { + std::ostringstream msg; + msg << tag << " Couldn't listen (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } +} + +TCPSocket TCPSocket::accept(const char* tag) { + int peer = ::accept(sock_, nullptr, nullptr); + if (peer < 0) { + std::ostringstream msg; + msg << tag << " Accept failed (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + return TCPSocket(peer); +} + +void TCPSocket::send(const char* tag, const void* data, size_t len) { + while (len > 0) { + auto n = ::send(sock_, data, len, 0); + if (n <= 0) { + std::ostringstream msg; + msg << tag << " Send failed with errno=" << errno; + throw std::runtime_error(msg.str()); + } + len -= n; + data = static_cast(data) + n; + } +} + +void TCPSocket::recv(const char* tag, void* data, size_t len) { + while (len > 0) { + auto n = ::recv(sock_, data, len, 0); + if (n <= 0) { + std::ostringstream msg; + msg << tag << " Recv failed with errno=" << errno; + throw std::runtime_error(msg.str()); + } + len -= n; + data = static_cast(data) + n; + } +} + +TCPSocket TCPSocket::connect( + const char* tag, + const address_t& addr, + int num_retries, + int wait, + std::function cb) { + int sock, success; + + // Attempt to connect `num_retries` times with exponential backoff. + for (int attempt = 0; attempt < num_retries; attempt++) { + // Create the socket + sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock < 0) { + std::ostringstream msg; + msg << tag << " Couldn't create socket to connect (error: " << errno + << ")"; + throw std::runtime_error(msg.str()); + } + + success = ::connect(sock, addr.get(), addr.len); + if (success == 0) { + break; + } + + cb(attempt, wait); + if (wait > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(wait)); + } + + wait <<= 1; + } + + if (success < 0) { + std::ostringstream msg; + msg << tag << " Couldn't connect (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + return TCPSocket(sock); +} + +} // namespace mlx::core::distributed::detail diff --git a/mlx/distributed/utils.h b/mlx/distributed/utils.h new file mode 100644 index 000000000..ef01cf09a --- /dev/null +++ b/mlx/distributed/utils.h @@ -0,0 +1,62 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace mlx::core::distributed::detail { + +struct address_t { + sockaddr_storage addr; + socklen_t len; + + const sockaddr* get() const { + return (struct sockaddr*)&addr; + } +}; + +/** + * Parse a sockaddr from an ip and port provided as strings. + */ +address_t parse_address(const std::string& ip, const std::string& port); + +/** + * Parse a sockaddr provided as an : string. + */ +address_t parse_address(const std::string& ip_port); + +/** + * Small wrapper over a TCP socket to simplify initiating connections. + */ +class TCPSocket { + public: + TCPSocket(const char* tag); + TCPSocket(const TCPSocket&) = delete; + TCPSocket& operator=(const TCPSocket&) = delete; + TCPSocket(TCPSocket&& s); + ~TCPSocket(); + + void listen(const char* tag, const address_t& addr); + TCPSocket accept(const char* tag); + + void send(const char* tag, const void* data, size_t len); + void recv(const char* tag, void* data, size_t len); + + operator int() const { + return sock_; + } + + static TCPSocket connect( + const char* tag, + const address_t& addr, + int num_retries = 1, + int wait = 0, + std::function cb = nullptr); + + private: + TCPSocket(int sock); + + int sock_; +}; + +} // namespace mlx::core::distributed::detail