From f3b605e53c03e417812bf04b192fdcf4dde8be42 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 17 Oct 2025 19:03:26 +0300 Subject: [PATCH] Add working reduce and semi-working all gather --- mlx/distributed/ibv/ibv.cpp | 347 +++++++++++++++--------------------- 1 file changed, 147 insertions(+), 200 deletions(-) diff --git a/mlx/distributed/ibv/ibv.cpp b/mlx/distributed/ibv/ibv.cpp index 1b655db78..2a4c4644b 100644 --- a/mlx/distributed/ibv/ibv.cpp +++ b/mlx/distributed/ibv/ibv.cpp @@ -70,9 +70,8 @@ struct Destination { }; std::ostream& operator<<(std::ostream& os, const Destination& dst) { - os << dst.local_id << " " << dst.queue_pair_number - << " " << dst.packet_sequence_number << " " - << dst.global_identifier; + os << dst.local_id << " " << dst.queue_pair_number << " " + << dst.packet_sequence_number << " " << dst.global_identifier; return os; } @@ -378,9 +377,10 @@ class SideChannel { sockets_.push_back(server.accept(IBV_TAG)); } - std::vector ranks(size-1); + std::vector ranks(size - 1); for (int i = 0; i < size - 1; i++) { - sockets_[i].recv(IBV_TAG, reinterpret_cast(&ranks[i]), sizeof(int)); + sockets_[i].recv( + IBV_TAG, reinterpret_cast(&ranks[i]), sizeof(int)); ranks[i]--; } for (int i = 0; i < size - 1; i++) { @@ -739,115 +739,75 @@ class IBVGroup : public GroupImpl { // Copy our data to the appropriate place std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes); + // Fully connected all gather + char* data = out_ptr; + char* our_data = out_ptr + rank_ * n_bytes; constexpr int64_t N = BUFFER_SIZE; + constexpr int PIPELINE = 2; int64_t total = static_cast(n_bytes); - int64_t offset = N; - int a = 0, b = 1; + int num_peers = size_ - 1; - int mask_init = 1 << rank_; - int mask_target = (1 << size_) - 1; + // Counters to maintain the state of transfers + int in_flight = 0; + int read_offset = 0; + int completed_send_count[PIPELINE] = {0}; + int write_offset[MAX_PEERS] = {0}; - post_recv_all(a); - std::copy( - in_ptr, - in_ptr + std::min(N, total), - cm_.send_buffer(a).begin()); - post_send_all(a); - - int mask_a_send = mask_init; - int mask_a_recv = mask_init; - int mask_b_send = mask_init; - int mask_b_recv = mask_init; - - while (offset < total) { - // While the previous chunk is in flight copy to the next send buffer + // Prefill the pipeline + int buff = 0; + while (read_offset < total && buff < PIPELINE) { + post_recv_all(buff); std::copy( - in_ptr + offset, - in_ptr + std::min(offset + N, total), - cm_.send_buffer(b).begin()); + our_data + read_offset, + our_data + std::min(read_offset + N, total), + cm_.send_buffer(buff).begin()); + post_send_all(buff); - // Send if the previous send is already done - for (int i = 0, m = 1; i < size_; i++, m *= 2) { - if (i == rank_) { - continue; - } - if (mask_a_send & m) { - cm_.send_to(i, b); - } - } - - // Recv the next buffer if the previous one is already done - for (int i = 0, m = 1; i < size_; i++, m *= 2) { - if (i == rank_) { - continue; - } - if (mask_a_recv & m) { - cm_.recv_from(i, b); - std::copy( - cm_.buffer(i, a).begin(), - cm_.buffer(i, a).begin() + std::min(N, total - offset), - out_ptr + i * n_bytes + std::max(offset - N, 0LL)); - } - } - - // Loop until this chunk is all done - while (mask_a_send != mask_target || mask_a_recv != mask_target) { - ibv_wc wc[8]; - int n = cm_.poll(8, wc); - for (int i = 0; i < n; i++) { - int work_type = wc[i].wr_id >> 16; - int buff = (wc[i].wr_id >> 8) & 0xff; - int rank = wc[i].wr_id & 0xff; - - if (work_type == SEND_WR) { - if (buff == a) { - cm_.send_to(rank, b); - mask_a_send |= 1 << rank; - } else { - mask_b_send |= 1 << rank; - } - } else { - if (buff == a) { - cm_.recv_from(rank, b); - mask_a_recv |= 1 << rank; - std::copy( - cm_.buffer(rank, a).begin(), - cm_.buffer(rank, a).begin() + - std::min(N, total - offset), - out_ptr + rank * n_bytes + std::max(offset - N, 0LL)); - } else { - mask_b_recv |= 1 << rank; - } - } - } - } - - std::swap(a, b); - mask_a_send = mask_b_send; - mask_a_recv = mask_b_recv; - mask_b_send = mask_init; - mask_b_recv = mask_init; - offset += N; + buff++; + in_flight += 2 * num_peers; + read_offset += N; } - { - while (mask_a_send != mask_target || mask_a_recv != mask_target) { - ibv_wc wc[8]; - int n = cm_.poll(8, wc); - for (int i = 0; i < n; i++) { - int work_type = wc[i].wr_id >> 16; - int buff = (wc[i].wr_id >> 8) & 0xff; - int rank = wc[i].wr_id & 0xff; + // Main loop + // + // Keep going until we have no longer data in flight. + while (in_flight > 0) { + ibv_wc wc[8]; + int n = cm_.poll(8, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; - if (work_type == SEND_WR) { - mask_a_send |= 1 << rank; - } else { - mask_a_recv |= 1 << rank; + in_flight--; + + // Send completed. If all sends completed then send the next chunk. + if (work_type == SEND_WR && read_offset < total) { + completed_send_count[buff]++; + if (completed_send_count[buff] == num_peers) { std::copy( - cm_.buffer(rank, a).begin(), - cm_.buffer(rank, a).begin() + - std::min(N, total + N - offset), - out_ptr + rank * n_bytes + offset - N); + our_data + read_offset, + our_data + std::min(read_offset + N, total), + cm_.send_buffer(buff).begin()); + post_send_all(buff); + + completed_send_count[buff] = 0; + in_flight += num_peers; + read_offset += N; + } + } + + // Recv completed. If we have more chunks then post another recv. + else if (work_type == RECV_WR) { + std::copy( + cm_.buffer(rank, buff).begin(), + cm_.buffer(rank, buff).begin() + + std::min(N, total - write_offset[rank]), + data + rank * n_bytes + write_offset[rank]); + write_offset[rank] += N; + if (write_offset[rank] + N * (PIPELINE - 1) < total) { + cm_.recv_from(rank, buff); + in_flight++; } } } @@ -902,112 +862,99 @@ class IBVGroup : public GroupImpl { // Fully connected all reduce T* data = out_ptr; constexpr int64_t N = BUFFER_SIZE / sizeof(T); + constexpr int PIPELINE = 2; int64_t total = static_cast(size); - int64_t offset = N; - int a = 0, b = 1; + int num_peers = size_ - 1; - int mask_init = 1 << rank_; - int mask_target = (1 << size_) - 1; + // Counters to maintain the state of transfers + int in_flight = 0; + int read_offset = 0; + int completed_send_count[PIPELINE] = {0}; + int completed_recv_begin[MAX_PEERS] = {0}; + int completed_recv_end[MAX_PEERS] = {0}; - // Handle the first piece of data. - post_recv_all(a); - std::copy(data, data + std::min(N, total), cm_.send_buffer(a).begin()); - post_send_all(a); - - int mask_a_send = mask_init; - int mask_a_recv = mask_init; - int mask_b_send = mask_init; - int mask_b_recv = mask_init; - - while (offset < total) { - // While the previous chunk is in flight copy to the next send buffer + // Prefill the pipeline + int buff = 0; + while (read_offset < total && buff < PIPELINE) { + post_recv_all(buff); std::copy( - data + offset, - data + std::min(offset + N, total), - cm_.send_buffer(b).begin()); + data + read_offset, + data + std::min(read_offset + N, total), + cm_.send_buffer(buff).begin()); + post_send_all(buff); - // Send if the previous send is already done - for (int i = 0, m = 1; i < size_; i++, m *= 2) { - if (i == rank_) { - continue; - } - if (mask_a_send & m) { - cm_.send_to(i, b); - } - } - - // Recv the next buffer if the previous one is already done - for (int i = 0, m = 1; i < size_; i++, m *= 2) { - if (i == rank_) { - continue; - } - if (mask_a_recv & m) { - cm_.recv_from(i, b); - reduce_op( - cm_.buffer(i, a).begin(), - data + std::max(offset - N, 0LL), - std::min(N, total - offset)); - } - } - - // Loop until this chunk is all done - while (mask_a_send != mask_target || mask_a_recv != mask_target) { - ibv_wc wc[8]; - int n = cm_.poll(8, wc); - for (int i = 0; i < n; i++) { - int work_type = wc[i].wr_id >> 16; - int buff = (wc[i].wr_id >> 8) & 0xff; - int rank = wc[i].wr_id & 0xff; - - if (work_type == SEND_WR) { - if (buff == a) { - cm_.send_to(rank, b); - mask_a_send |= 1 << rank; - } else { - mask_b_send |= 1 << rank; - } - } else { - if (buff == a) { - cm_.recv_from(rank, b); - mask_a_recv |= 1 << rank; - reduce_op( - cm_.buffer(rank, a).begin(), - data + std::max(offset - N, 0LL), - std::min(N, total - offset)); - } else { - mask_b_recv |= 1 << rank; - } - } - } - } - - std::swap(a, b); - mask_a_send = mask_b_send; - mask_a_recv = mask_b_recv; - mask_b_send = mask_init; - mask_b_recv = mask_init; - offset += N; + buff++; + in_flight += 2 * num_peers; + read_offset += N; } - { - while (mask_a_send != mask_target || mask_a_recv != mask_target) { - ibv_wc wc[8]; - int n = cm_.poll(8, wc); - for (int i = 0; i < n; i++) { - int work_type = wc[i].wr_id >> 16; - int buff = (wc[i].wr_id >> 8) & 0xff; - int rank = wc[i].wr_id & 0xff; + // Main loop + // + // Keep going until we have no longer data in flight. + while (in_flight > 0) { + // Poll the hardware for completions. + // + // If a send was completed mark how many completions we have received + // for that buffer. If we have sent the buffer to all peers we can + // reuse the buffer so copy the next chunk of data and send it to all. + // + // If a receive is completed then advance the pointer of completed + // receives. + ibv_wc wc[8]; + int n = cm_.poll(8, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; - if (work_type == SEND_WR) { - mask_a_send |= 1 << rank; - } else { - mask_a_recv |= 1 << rank; - reduce_op( - cm_.buffer(rank, a).begin(), - data + offset - N, - std::min(N, total + N - offset)); + in_flight--; + + if (work_type == SEND_WR && read_offset < total) { + completed_send_count[buff]++; + if (completed_send_count[buff] == num_peers) { + std::copy( + data + read_offset, + data + std::min(read_offset + N, total), + cm_.send_buffer(buff).begin()); + post_send_all(buff); + + completed_send_count[buff] = 0; + in_flight += num_peers; + read_offset += N; } } + + else if (work_type == RECV_WR) { + completed_recv_end[rank]++; + } + } + + // Process the completed recv + // + // For each rank we have a range of completed recv defined by a begin + // and end inclusive and exlusive in standard C++ fashion. + // + // When there is an unprocessed receive we first check if we have + // finished sending the write location. If so then we reduce in-place + // and then check if there is more to be received and post a recv. + for (int r = 0; r < size_; r++) { + int s = completed_recv_begin[r]; + int e = completed_recv_end[r]; + int w = s * N; + while (w < read_offset && e - s > 0) { + int buff = s % PIPELINE; + reduce_op( + cm_.buffer(r, buff).begin(), + data + w, + std::min(N, total - w)); + w += N; + s++; + if (w + (PIPELINE - 1) * N < total) { + cm_.recv_from(r, buff); + in_flight++; + } + } + completed_recv_begin[r] = s; } } });