From 4dbffb39545b659a086742a919477fa175964fab Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 1 Oct 2025 01:36:59 -0700 Subject: [PATCH] All gather --- mlx/distributed/ibv/ibv.cpp | 229 ++++++++++++++++++++++---------- mlx/distributed/reduction_ops.h | 43 ++++++ mlx/distributed/ring/ring.cpp | 40 +----- 3 files changed, 206 insertions(+), 106 deletions(-) create mode 100644 mlx/distributed/reduction_ops.h diff --git a/mlx/distributed/ibv/ibv.cpp b/mlx/distributed/ibv/ibv.cpp index 4b829a182..37e5d1113 100644 --- a/mlx/distributed/ibv/ibv.cpp +++ b/mlx/distributed/ibv/ibv.cpp @@ -1,14 +1,15 @@ // Copyright © 2025 Apple Inc. #include +#include #include #include -#include #include #include "mlx/backend/cpu/encoder.h" #include "mlx/distributed/distributed_impl.h" +#include "mlx/distributed/reduction_ops.h" #include "mlx/distributed/utils.h" #include "mlx/dtype_utils.h" @@ -49,7 +50,7 @@ std::ostream& operator<<(std::ostream& os, const ibv_gid& gid) { void* page_aligned_alloc(size_t num_bytes) { static size_t page_size = sysconf(_SC_PAGESIZE); - void * buf; + void* buf; if (posix_memalign(&buf, page_size, num_bytes)) { return nullptr; } @@ -74,8 +75,7 @@ struct Destination { class SharedBuffer { public: SharedBuffer(size_t num_bytes) - : data_(page_aligned_alloc(num_bytes)), - num_bytes_(num_bytes) {} + : data_(page_aligned_alloc(num_bytes)), num_bytes_(num_bytes) {} ~SharedBuffer() { for (auto& [pd, mr] : memory_regions_) { ibv_dereg_mr(mr); @@ -87,8 +87,7 @@ class SharedBuffer { SharedBuffer(const SharedBuffer&) = delete; SharedBuffer& operator=(const SharedBuffer&) = delete; - SharedBuffer(SharedBuffer&& b) - : data_(nullptr), num_bytes_(0) { + SharedBuffer(SharedBuffer&& b) : data_(nullptr), num_bytes_(0) { std::swap(data_, b.data_); std::swap(num_bytes_, b.num_bytes_); std::swap(memory_regions_, b.memory_regions_); @@ -556,8 +555,7 @@ class ConnectionManager { void send_to(int rank, int buff) { connections_[rank].post_send( - buffers_[buff * size_ + 
rank_], - SEND_WR << 16 | buff << 8 | rank); + buffers_[buff * size_ + rank_], SEND_WR << 16 | buff << 8 | rank); } void recv_from(int rank, int buff) { @@ -594,9 +592,7 @@ class ConnectionManager { */ int poll(int rank, int num_completions, ibv_wc* work_completions) { return ibv_poll_cq( - connections_[rank].completion_queue, - num_completions, - work_completions); + connections_[rank].completion_queue, num_completions, work_completions); } SharedBuffer& send_buffer(int buff) { @@ -646,17 +642,6 @@ class ConnectionManager { std::vector buffers_; }; -template -struct SumOp { - void operator()(const T* input, T* output, size_t N) const { - while (N-- > 0) { - *output += *input; - input++; - output++; - } - } -}; - std::vector load_device_names(int rank, const char* dev_file) { std::vector device_names; std::ifstream f(dev_file); @@ -678,7 +663,8 @@ namespace mlx::core::distributed::ibv { class IBVGroup : public GroupImpl { public: - IBVGroup(ConnectionManager cm) : cm_(std::move(cm)), rank_(cm.rank()), size_(cm.size()) {} + IBVGroup(ConnectionManager cm) + : cm_(std::move(cm)), rank_(cm.rank()), size_(cm.size()) {} Stream communication_stream(StreamOrDevice s) override { return to_stream(s, Device::cpu); @@ -695,24 +681,165 @@ class IBVGroup : public GroupImpl { void all_sum(const array& input, array& output, Stream stream) override { dispatch_all_types(output.dtype(), [&](auto type_tag) { using T = MLX_GET_TYPE(type_tag); - all_reduce(input, output, stream, SumOp{}); + all_reduce(input, output, stream, detail::SumOp{}); }); } void all_max(const array& input, array& output, Stream stream) override { + dispatch_all_types(output.dtype(), [&](auto type_tag) { + using T = MLX_GET_TYPE(type_tag); + all_reduce(input, output, stream, detail::MaxOp{}); + }); } void all_min(const array& input, array& output, Stream stream) override { + dispatch_all_types(output.dtype(), [&](auto type_tag) { + using T = MLX_GET_TYPE(type_tag); + all_reduce(input, output, stream, 
detail::MinOp{}); + }); } void all_gather(const array& input, array& output, Stream stream) override { + auto in_ptr = input.data(); + auto out_ptr = output.data(); + size_t n_bytes = input.nbytes(); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(input); + encoder.set_output_array(output); + encoder.dispatch([in_ptr, out_ptr, n_bytes, this]() { + // Copy our data to the appropriate place + std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes); + + constexpr int64_t N = BUFFER_SIZE; + int64_t total = static_cast(n_bytes); + int64_t offset = N; + int a = 0, b = 1; + + int mask_init = 1 << rank_; + int mask_target = (1 << size_) - 1; + + post_recv_all(a); + std::copy( + in_ptr, + in_ptr + std::min(N, total), + cm_.send_buffer(a).begin()); + post_send_all(a); + + int mask_a_send = mask_init; + int mask_a_recv = mask_init; + int mask_b_send = mask_init; + int mask_b_recv = mask_init; + + while (offset < total) { + // While the previous chunk is in flight copy to the next send buffer + std::copy( + in_ptr + offset, + in_ptr + std::min(offset + N, total), + cm_.send_buffer(b).begin()); + + // Send if the previous send is already done + for (int i = 0, m = 1; i < size_; i++, m *= 2) { + if (i == rank_) { + continue; + } + if (mask_a_send & m) { + cm_.send_to(i, b); + } + } + + // Recv the next buffer if the previous one is done (always a full chunk) + for (int i = 0, m = 1; i < size_; i++, m *= 2) { + if (i == rank_) { + continue; + } + if (mask_a_recv & m) { + cm_.recv_from(i, b); + std::copy( + cm_.buffer(i, a).begin(), + cm_.buffer(i, a).begin() + N, + out_ptr + i * n_bytes + offset - N); + } + } + + // Loop until this chunk is all done + while (mask_a_send != mask_target || mask_a_recv != mask_target) { + ibv_wc wc[8]; + int n = cm_.poll(8, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; + + if (work_type == 
SEND_WR) { + if (buff == a) { + cm_.send_to(rank, b); + mask_a_send |= 1 << rank; + } else { + mask_b_send |= 1 << rank; + } + } else { + if (buff == a) { + cm_.recv_from(rank, b); + mask_a_recv |= 1 << rank; + std::copy( + cm_.buffer(rank, a).begin(), + cm_.buffer(rank, a).begin() + N, + out_ptr + rank * n_bytes + offset - N); + } else { + mask_b_recv |= 1 << rank; + } + } + } + } + + std::swap(a, b); + mask_a_send = mask_b_send; + mask_a_recv = mask_b_recv; + mask_b_send = mask_init; + mask_b_recv = mask_init; + offset += N; + } + + // Peers whose receive of the last chunk completed while the loop above + // was still polling are already set in mask_a_recv and produce no further + // completion, so copy their data here. + for (int i = 0; i < size_; i++) { + if (i != rank_ && (mask_a_recv & (1 << i))) { + std::copy( + cm_.buffer(i, a).begin(), + cm_.buffer(i, a).begin() + std::min(N, total + N - offset), + out_ptr + i * n_bytes + offset - N); + } + } + + { + while (mask_a_send != mask_target || mask_a_recv != mask_target) { + ibv_wc wc[8]; + int n = cm_.poll(8, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; + + if (work_type == SEND_WR) { + mask_a_send |= 1 << rank; + } else { + mask_a_recv |= 1 << rank; + std::copy( + cm_.buffer(rank, a).begin(), + cm_.buffer(rank, a).begin() + + std::min(N, total + N - offset), + out_ptr + rank * n_bytes + offset - N); + } + } + } + } + }); } - void send(const array& input, int dst, Stream stream) override { - } + void send(const array& input, int dst, Stream stream) override {} - void recv(array& out, int src, Stream stream) override { - } + void recv(array& out, int src, Stream stream) override {} std::shared_ptr split(int color, int key = -1) override { throw std::runtime_error("[ibv] Group split not supported."); @@ -746,6 +873,7 @@ class IBVGroup : public GroupImpl { auto in_ptr = input.data(); auto out_ptr = output.data(); 
auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(input); encoder.set_output_array(output); encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() { // If not inplace all reduce then copy the input to the output first @@ -765,10 +893,7 @@ class IBVGroup : public GroupImpl { // Handle the first piece of data. 
post_recv_all(a); - std::copy( - data, - data + std::min(N, total), - cm_.send_buffer(a).begin()); + std::copy(data, data + std::min(N, total), cm_.send_buffer(a).begin()); post_send_all(a); int mask_a_send = mask_init; @@ -905,8 +1030,8 @@ std::shared_ptr init(bool strict /* = false */) { auto cm = ConnectionManager(rank, device_names, coordinator); if (cm.size() > MAX_PEERS) { std::ostringstream msg; - msg << "[ibv] The maximum number of supported peers is " - << MAX_PEERS << " but " << cm.size() << " was provided"; + msg << "[ibv] The maximum number of supported peers is " << MAX_PEERS + << " but " << cm.size() << " was provided"; throw std::runtime_error(msg.str()); } @@ -914,42 +1039,6 @@ std::shared_ptr init(bool strict /* = false */) { cm.barrier(); return std::make_shared(std::move(cm)); - - //cm.recv_from(rank ^ 1, 0); - //cm.barrier(); - //std::fill( - // cm.buffer(rank, 0).begin(), - // cm.buffer(rank, 0).end(), - // 2 * rank + 1); - //cm.barrier(); - //cm.send_to(rank ^ 1, 0); - - //ibv_wc wc[8]; - //bool sent = false, recvd = false; - //while (!(sent && recvd)) { - // int num_completions = cm.poll(8, wc); - // for (int i = 0; i < num_completions; i++) { - // if (wc[i].status != IBV_WC_SUCCESS) { - // std::cout << "Error " << wc[i].status << std::endl; - // } - - // int work_type = wc[i].wr_id >> 16; - // int buff = (wc[i].wr_id >> 8) & 0xff; - // int rank = wc[i].wr_id & 0xff; - - // std::cout << work_type << " " << buff << " " << rank << std::endl; - // print_wc(wc[i]); - - // sent |= (work_type == SEND_WR); - // recvd |= (work_type == RECV_WR); - // } - //} - - //std::cout << rank << " ours " << *cm.buffer(rank, 0).data() << std::endl; - //std::cout << rank << " theirs " << *cm.buffer(rank ^ 1, 0).data() - // << std::endl; - - //return nullptr; } } // namespace mlx::core::distributed::ibv diff --git a/mlx/distributed/reduction_ops.h b/mlx/distributed/reduction_ops.h new file mode 100644 index 000000000..02777be39 --- /dev/null +++ 
b/mlx/distributed/reduction_ops.h @@ -0,0 +1,43 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include <algorithm> +#include <cstddef> + +namespace mlx::core::distributed::detail { + +template +struct SumOp { + void operator()(const T* input, T* output, size_t N) const { + while (N-- > 0) { + *output += *input; + input++; + output++; + } + } +}; + +template +struct MaxOp { + void operator()(const T* input, T* output, size_t N) const { + while (N-- > 0) { + *output = std::max(*output, *input); + input++; + output++; + } + } +}; + +template +struct MinOp { + void operator()(const T* input, T* output, size_t N) const { + while (N-- > 0) { + *output = std::min(*output, *input); + input++; + output++; + } + } +}; + +} // namespace mlx::core::distributed::detail diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index b3f293458..c1275737a 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -19,6 +19,7 @@ #include "mlx/backend/cpu/encoder.h" #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" +#include "mlx/distributed/reduction_ops.h" #include "mlx/distributed/utils.h" #include "mlx/threadpool.h" @@ -373,39 +374,6 @@ std::vector make_connections( return sockets; } -template -struct SumOp { - void operator()(const T* input, T* output, size_t N) { - while (N-- > 0) { - *output += *input; - input++; - output++; - } - } -}; - -template -struct MaxOp { - void operator()(const T* input, T* output, size_t N) { - while (N-- > 0) { - *output = std::max(*output, *input); - input++; - output++; - } - } -}; - -template -struct MinOp { - void operator()(const T* input, T* output, size_t N) { - while (N-- > 0) { - *output = std::min(*output, *input); - input++; - output++; - } - } -}; - } // namespace class RingGroup : public GroupImpl { @@ -506,17 +474,17 @@ class RingGroup : public GroupImpl { void all_sum(const array& input, array& output, Stream stream) override { SWITCH_TYPE( - 
output, all_reduce>(input, output, stream, SumOp())); + output, 
all_reduce(input, output, stream, detail::SumOp())); } void all_max(const array& input, array& output, Stream stream) override { SWITCH_TYPE( - output, all_reduce>(input, output, stream, MaxOp())); + output, all_reduce(input, output, stream, detail::MaxOp())); } void all_min(const array& input, array& output, Stream stream) override { SWITCH_TYPE( - output, all_reduce>(input, output, stream, MinOp())); + output, all_reduce(input, output, stream, detail::MinOp())); } std::shared_ptr split(int color, int key = -1) override {