All gather

Angelos Katharopoulos
2025-10-01 01:36:59 -07:00
parent b1a60b2d2d
commit 4dbffb3954
3 changed files with 190 additions and 106 deletions


@@ -1,14 +1,15 @@
// Copyright © 2025 Apple Inc.
#include <infiniband/verbs.h>
#include <unistd.h>
#include <fstream>
#include <iostream>
#include <unistd.h>
#include <json.hpp>
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/reduction_ops.h"
#include "mlx/distributed/utils.h"
#include "mlx/dtype_utils.h"
@@ -49,7 +50,7 @@ std::ostream& operator<<(std::ostream& os, const ibv_gid& gid) {
void* page_aligned_alloc(size_t num_bytes) {
static size_t page_size = sysconf(_SC_PAGESIZE);
void * buf;
void* buf;
if (posix_memalign(&buf, page_size, num_bytes)) {
return nullptr;
}
@@ -74,8 +75,7 @@ struct Destination {
class SharedBuffer {
public:
SharedBuffer(size_t num_bytes)
: data_(page_aligned_alloc(num_bytes)),
num_bytes_(num_bytes) {}
: data_(page_aligned_alloc(num_bytes)), num_bytes_(num_bytes) {}
~SharedBuffer() {
for (auto& [pd, mr] : memory_regions_) {
ibv_dereg_mr(mr);
@@ -87,8 +87,7 @@ class SharedBuffer {
SharedBuffer(const SharedBuffer&) = delete;
SharedBuffer& operator=(const SharedBuffer&) = delete;
SharedBuffer(SharedBuffer&& b)
: data_(nullptr), num_bytes_(0) {
SharedBuffer(SharedBuffer&& b) : data_(nullptr), num_bytes_(0) {
std::swap(data_, b.data_);
std::swap(num_bytes_, b.num_bytes_);
std::swap(memory_regions_, b.memory_regions_);
@@ -556,8 +555,7 @@ class ConnectionManager {
void send_to(int rank, int buff) {
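// The work request id packs the routing info needed to match a completion
// back to its origin: bits 16+ hold the work type (SEND_WR/RECV_WR),
// bits 8-15 the buffer slot, and bits 0-7 the peer rank.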
connections_[rank].post_send(
buffers_[buff * size_ + rank_],
SEND_WR << 16 | buff << 8 | rank);
buffers_[buff * size_ + rank_], SEND_WR << 16 | buff << 8 | rank);
}
void recv_from(int rank, int buff) {
@@ -594,9 +592,7 @@ class ConnectionManager {
*/
int poll(int rank, int num_completions, ibv_wc* work_completions) {
return ibv_poll_cq(
connections_[rank].completion_queue,
num_completions,
work_completions);
connections_[rank].completion_queue, num_completions, work_completions);
}
SharedBuffer& send_buffer(int buff) {
@@ -646,17 +642,6 @@ class ConnectionManager {
std::vector<SharedBuffer> buffers_;
};
template <typename T>
struct SumOp {
void operator()(const T* input, T* output, size_t N) const {
while (N-- > 0) {
*output += *input;
input++;
output++;
}
}
};
std::vector<std::string> load_device_names(int rank, const char* dev_file) {
std::vector<std::string> device_names;
std::ifstream f(dev_file);
@@ -678,7 +663,8 @@ namespace mlx::core::distributed::ibv {
class IBVGroup : public GroupImpl {
public:
IBVGroup(ConnectionManager cm) : cm_(std::move(cm)), rank_(cm.rank()), size_(cm.size()) {}
IBVGroup(ConnectionManager cm)
: cm_(std::move(cm)), rank_(cm.rank()), size_(cm.size()) {}
Stream communication_stream(StreamOrDevice s) override {
return to_stream(s, Device::cpu);
@@ -695,24 +681,154 @@ class IBVGroup : public GroupImpl {
void all_sum(const array& input, array& output, Stream stream) override {
dispatch_all_types(output.dtype(), [&](auto type_tag) {
using T = MLX_GET_TYPE(type_tag);
all_reduce<T>(input, output, stream, SumOp<T>{});
all_reduce<T>(input, output, stream, detail::SumOp<T>{});
});
}
void all_max(const array& input, array& output, Stream stream) override {
dispatch_all_types(output.dtype(), [&](auto type_tag) {
using T = MLX_GET_TYPE(type_tag);
all_reduce<T>(input, output, stream, detail::MaxOp<T>{});
});
}
void all_min(const array& input, array& output, Stream stream) override {
dispatch_all_types(output.dtype(), [&](auto type_tag) {
using T = MLX_GET_TYPE(type_tag);
all_reduce<T>(input, output, stream, detail::MinOp<T>{});
});
}
void all_gather(const array& input, array& output, Stream stream) override {
auto in_ptr = input.data<char>();
auto out_ptr = output.data<char>();
size_t n_bytes = input.nbytes();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.set_output_array(output);
encoder.dispatch([in_ptr, out_ptr, n_bytes, this]() {
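// The gather runs as a pipelined exchange over fixed-size chunks: while one
// chunk (slot a) is in flight to and from every peer, the next chunk is
// staged in slot b. Each peer's contribution lands at rank * n_bytes in the
// output. Per-slot bitmasks track which peers have completed their send and
// receive; our own bit (mask_init) is pre-set, and a slot is done when its
// masks reach mask_target, i.e. every rank's bit is set.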
// Copy our data to the appropriate place
std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes);
constexpr int64_t N = BUFFER_SIZE;
int64_t total = static_cast<int64_t>(n_bytes);
int64_t offset = N;
int a = 0, b = 1;
int mask_init = 1 << rank_;
int mask_target = (1 << size_) - 1;
post_recv_all(a);
std::copy(
in_ptr,
in_ptr + std::min(N, total),
cm_.send_buffer(a).begin<char>());
post_send_all(a);
int mask_a_send = mask_init;
int mask_a_recv = mask_init;
int mask_b_send = mask_init;
int mask_b_recv = mask_init;
while (offset < total) {
// While the previous chunk is in flight copy to the next send buffer
std::copy(
in_ptr + offset,
in_ptr + std::min(offset + N, total),
cm_.send_buffer(b).begin<char>());
// Send if the previous send is already done
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
if (i == rank_) {
continue;
}
if (mask_a_send & m) {
cm_.send_to(i, b);
}
}
// Recv the next buffer if the previous one is already done
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
if (i == rank_) {
continue;
}
if (mask_a_recv & m) {
cm_.recv_from(i, b);
std::copy(
cm_.buffer(i, a).begin<char>(),
cm_.buffer(i, a).begin<char>() + std::min(N, total - offset),
out_ptr + i * n_bytes + std::max(offset - N, 0LL));
}
}
// Loop until this chunk is all done
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
ibv_wc wc[8];
int n = cm_.poll(8, wc);
for (int i = 0; i < n; i++) {
int work_type = wc[i].wr_id >> 16;
int buff = (wc[i].wr_id >> 8) & 0xff;
int rank = wc[i].wr_id & 0xff;
if (work_type == SEND_WR) {
if (buff == a) {
cm_.send_to(rank, b);
mask_a_send |= 1 << rank;
} else {
mask_b_send |= 1 << rank;
}
} else {
if (buff == a) {
cm_.recv_from(rank, b);
mask_a_recv |= 1 << rank;
std::copy(
cm_.buffer(rank, a).begin<char>(),
cm_.buffer(rank, a).begin<char>() +
std::min(N, total - offset),
out_ptr + rank * n_bytes + std::max(offset - N, 0LL));
} else {
mask_b_recv |= 1 << rank;
}
}
}
}
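// Chunk done: slot b becomes the in-flight slot, its masks roll forward,
// and b's masks are reset for the chunk staged on the next iteration.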
std::swap(a, b);
mask_a_send = mask_b_send;
mask_a_recv = mask_b_recv;
mask_b_send = mask_init;
mask_b_recv = mask_init;
offset += N;
}
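// Drain the last in-flight chunk: wait for the remaining completions and
// copy the final received pieces into place.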
{
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
ibv_wc wc[8];
int n = cm_.poll(8, wc);
for (int i = 0; i < n; i++) {
int work_type = wc[i].wr_id >> 16;
int buff = (wc[i].wr_id >> 8) & 0xff;
int rank = wc[i].wr_id & 0xff;
if (work_type == SEND_WR) {
mask_a_send |= 1 << rank;
} else {
mask_a_recv |= 1 << rank;
std::copy(
cm_.buffer(rank, a).begin<char>(),
cm_.buffer(rank, a).begin<char>() +
std::min(N, total + N - offset),
out_ptr + rank * n_bytes + offset - N);
}
}
}
}
});
}
void send(const array& input, int dst, Stream stream) override {
}
void send(const array& input, int dst, Stream stream) override {}
void recv(array& out, int src, Stream stream) override {
}
void recv(array& out, int src, Stream stream) override {}
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
throw std::runtime_error("[ibv] Group split not supported.");
@@ -746,6 +862,7 @@ class IBVGroup : public GroupImpl {
auto in_ptr = input.data<T>();
auto out_ptr = output.data<T>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.set_output_array(output);
encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() {
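// The reduction follows the same chunked, double-buffered pipeline as
// all_gather, except each received chunk is folded into the output with
// reduce_op rather than copied to a per-rank offset.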
// If not inplace all reduce then copy the input to the output first
@@ -765,10 +882,7 @@ class IBVGroup : public GroupImpl {
// Handle the first piece of data.
post_recv_all(a);
std::copy(
data,
data + std::min(N, total),
cm_.send_buffer(a).begin<T>());
std::copy(data, data + std::min(N, total), cm_.send_buffer(a).begin<T>());
post_send_all(a);
int mask_a_send = mask_init;
@@ -905,8 +1019,8 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
auto cm = ConnectionManager(rank, device_names, coordinator);
if (cm.size() > MAX_PEERS) {
std::ostringstream msg;
msg << "[ibv] The maximum number of supported peers is "
<< MAX_PEERS << " but " << cm.size() << " was provided";
msg << "[ibv] The maximum number of supported peers is " << MAX_PEERS
<< " but " << cm.size() << " was provided";
throw std::runtime_error(msg.str());
}
@@ -914,42 +1028,6 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
cm.barrier();
return std::make_shared<IBVGroup>(std::move(cm));
//cm.recv_from(rank ^ 1, 0);
//cm.barrier();
//std::fill(
// cm.buffer(rank, 0).begin<int>(),
// cm.buffer(rank, 0).end<int>(),
// 2 * rank + 1);
//cm.barrier();
//cm.send_to(rank ^ 1, 0);
//ibv_wc wc[8];
//bool sent = false, recvd = false;
//while (!(sent && recvd)) {
// int num_completions = cm.poll(8, wc);
// for (int i = 0; i < num_completions; i++) {
// if (wc[i].status != IBV_WC_SUCCESS) {
// std::cout << "Error " << wc[i].status << std::endl;
// }
// int work_type = wc[i].wr_id >> 16;
// int buff = (wc[i].wr_id >> 8) & 0xff;
// int rank = wc[i].wr_id & 0xff;
// std::cout << work_type << " " << buff << " " << rank << std::endl;
// print_wc(wc[i]);
// sent |= (work_type == SEND_WR);
// recvd |= (work_type == RECV_WR);
// }
//}
//std::cout << rank << " ours " << *cm.buffer(rank, 0).data<int>() << std::endl;
//std::cout << rank << " theirs " << *cm.buffer(rank ^ 1, 0).data<int>()
// << std::endl;
//return nullptr;
}
} // namespace mlx::core::distributed::ibv