Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
All gather
@@ -1,14 +1,15 @@
 // Copyright © 2025 Apple Inc.

 #include <infiniband/verbs.h>
-#include <unistd.h>
 #include <fstream>
 #include <iostream>
+#include <unistd.h>

 #include <json.hpp>

 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/distributed/distributed_impl.h"
+#include "mlx/distributed/reduction_ops.h"
 #include "mlx/distributed/utils.h"
 #include "mlx/dtype_utils.h"

@@ -49,7 +50,7 @@ std::ostream& operator<<(std::ostream& os, const ibv_gid& gid) {

 void* page_aligned_alloc(size_t num_bytes) {
   static size_t page_size = sysconf(_SC_PAGESIZE);
-  void * buf;
+  void* buf;
   if (posix_memalign(&buf, page_size, num_bytes)) {
     return nullptr;
   }
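Note: page alignment matters here because these buffers are later registered with the NIC (ibv_reg_mr pins whole pages), so aligning the allocation avoids pinning unrelated data that shares a page. A minimal standalone sketch of the same allocation pattern (hypothetical test harness, not part of the commit):

#include <unistd.h>
#include <cstdio>
#include <cstdlib>

int main() {
  size_t page_size = sysconf(_SC_PAGESIZE);
  void* buf = nullptr;
  // posix_memalign returns 0 on success and an error code otherwise;
  // the alignment must be a power-of-two multiple of sizeof(void*).
  if (posix_memalign(&buf, page_size, 1 << 20) != 0) {
    return 1;
  }
  std::printf("allocated 1 MiB at %p (page size %zu)\n", buf, page_size);
  std::free(buf);  // memory from posix_memalign is released with free()
  return 0;
}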
@@ -74,8 +75,7 @@ struct Destination {
 class SharedBuffer {
  public:
   SharedBuffer(size_t num_bytes)
-      : data_(page_aligned_alloc(num_bytes)),
-        num_bytes_(num_bytes) {}
+      : data_(page_aligned_alloc(num_bytes)), num_bytes_(num_bytes) {}
   ~SharedBuffer() {
     for (auto& [pd, mr] : memory_regions_) {
       ibv_dereg_mr(mr);
@@ -87,8 +87,7 @@ class SharedBuffer {

   SharedBuffer(const SharedBuffer&) = delete;
   SharedBuffer& operator=(const SharedBuffer&) = delete;
-  SharedBuffer(SharedBuffer&& b)
-      : data_(nullptr), num_bytes_(0) {
+  SharedBuffer(SharedBuffer&& b) : data_(nullptr), num_bytes_(0) {
     std::swap(data_, b.data_);
     std::swap(num_bytes_, b.num_bytes_);
     std::swap(memory_regions_, b.memory_regions_);
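Note: the move constructor starts from an empty state and swaps, so the moved-from buffer ends up owning nothing and its destructor deregisters no memory regions. A small sketch of the same swap-based move idiom on a generic RAII handle (hypothetical type, not from the commit):

#include <utility>

struct Handle {
  void* resource = nullptr;
  Handle() = default;
  Handle(const Handle&) = delete;
  Handle& operator=(const Handle&) = delete;
  Handle(Handle&& other) : resource(nullptr) {
    std::swap(resource, other.resource);  // other is left empty
  }
  ~Handle() {
    if (resource) { /* release(resource) exactly once, by the final owner */ }
  }
};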
@@ -556,8 +555,7 @@ class ConnectionManager {

   void send_to(int rank, int buff) {
     connections_[rank].post_send(
-        buffers_[buff * size_ + rank_],
-        SEND_WR << 16 | buff << 8 | rank);
+        buffers_[buff * size_ + rank_], SEND_WR << 16 | buff << 8 | rank);
   }

   void recv_from(int rank, int buff) {
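Note: each work request packs three fields into one wr_id — the work type in bits 16 and up, the buffer index in bits 8–15, and the peer rank in bits 0–7 — and the completion loop in all_gather below unpacks them with matching shifts and masks. A self-contained sketch of the encoding (SEND_WR/RECV_WR values here are illustrative, not the commit's constants):

#include <cassert>
#include <cstdint>

constexpr int SEND_WR = 1;  // illustrative; the commit defines its own values
constexpr int RECV_WR = 2;

uint64_t pack(int work_type, int buff, int rank) {
  // work type in bits 16+, buffer index in bits 8-15, peer rank in bits 0-7
  return work_type << 16 | buff << 8 | rank;
}

int main() {
  uint64_t wr_id = pack(SEND_WR, /*buff=*/1, /*rank=*/3);
  assert((wr_id >> 16) == SEND_WR);    // work type
  assert(((wr_id >> 8) & 0xff) == 1);  // buffer index
  assert((wr_id & 0xff) == 3);         // rank, so ranks must fit in 8 bits
  return 0;
}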
@@ -594,9 +592,7 @@ class ConnectionManager {
   */
   int poll(int rank, int num_completions, ibv_wc* work_completions) {
     return ibv_poll_cq(
-        connections_[rank].completion_queue,
-        num_completions,
-        work_completions);
+        connections_[rank].completion_queue, num_completions, work_completions);
   }

   SharedBuffer& send_buffer(int buff) {
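Note: ibv_poll_cq drains up to num_completions entries from the completion queue without blocking and returns how many it wrote (negative on error). A minimal busy-poll loop over one CQ, with the status check the commit's (now deleted) debug code also performed — a sketch, with the CQ assumed to come from elsewhere:

#include <infiniband/verbs.h>
#include <cstdio>

// Drain `expected` completions from a CQ obtained elsewhere (ibv_create_cq).
void drain(ibv_cq* cq, int expected) {
  ibv_wc wc[8];
  while (expected > 0) {
    int n = ibv_poll_cq(cq, 8, wc);  // non-blocking; returns entries written
    for (int i = 0; i < n; i++) {
      if (wc[i].status != IBV_WC_SUCCESS) {
        std::fprintf(stderr, "completion error: %d\n", wc[i].status);
      }
      expected--;
    }
  }
}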
@@ -646,17 +642,6 @@ class ConnectionManager {
   std::vector<SharedBuffer> buffers_;
 };

-template <typename T>
-struct SumOp {
-  void operator()(const T* input, T* output, size_t N) const {
-    while (N-- > 0) {
-      *output += *input;
-      input++;
-      output++;
-    }
-  }
-};
-
 std::vector<std::string> load_device_names(int rank, const char* dev_file) {
   std::vector<std::string> device_names;
   std::ifstream f(dev_file);
@@ -678,7 +663,8 @@ namespace mlx::core::distributed::ibv {

 class IBVGroup : public GroupImpl {
  public:
-  IBVGroup(ConnectionManager cm) : cm_(std::move(cm)), rank_(cm.rank()), size_(cm.size()) {}
+  IBVGroup(ConnectionManager cm)
+      : cm_(std::move(cm)), rank_(cm.rank()), size_(cm.size()) {}

   Stream communication_stream(StreamOrDevice s) override {
     return to_stream(s, Device::cpu);
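Note on the constructor above: members initialize in declaration order, so if cm_ is declared before rank_ and size_, the calls cm.rank() and cm.size() run on an already moved-from ConnectionManager; whether that yields the right values depends on its move semantics, and reading from cm_ instead would sidestep the question. A generic sketch of the hazard (hypothetical Widget type, not a claim about this commit's behavior):

#include <cstddef>
#include <string>
#include <utility>

struct Widget {
  std::string name_;  // declared first, therefore initialized first
  std::size_t len_;

  // Hazard: by the time len_(w.name_.size()) runs, w.name_ has already been
  // moved into name_, so len_ observes the moved-from (empty) string.
  Widget(Widget&& w) : name_(std::move(w.name_)), len_(w.name_.size()) {}

  // Safe variant: read the already-initialized member instead:
  //   Widget(Widget&& w) : name_(std::move(w.name_)), len_(name_.size()) {}
};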
@@ -695,24 +681,154 @@ class IBVGroup : public GroupImpl {
   void all_sum(const array& input, array& output, Stream stream) override {
     dispatch_all_types(output.dtype(), [&](auto type_tag) {
       using T = MLX_GET_TYPE(type_tag);
-      all_reduce<T>(input, output, stream, SumOp<T>{});
+      all_reduce<T>(input, output, stream, detail::SumOp<T>{});
     });
   }

+  void all_max(const array& input, array& output, Stream stream) override {
+    dispatch_all_types(output.dtype(), [&](auto type_tag) {
+      using T = MLX_GET_TYPE(type_tag);
+      all_reduce<T>(input, output, stream, detail::MaxOp<T>{});
+    });
+  }
+
+  void all_min(const array& input, array& output, Stream stream) override {
+    dispatch_all_types(output.dtype(), [&](auto type_tag) {
+      using T = MLX_GET_TYPE(type_tag);
+      all_reduce<T>(input, output, stream, detail::MinOp<T>{});
+    });
+  }
+
+  void all_gather(const array& input, array& output, Stream stream) override {
+    auto in_ptr = input.data<char>();
+    auto out_ptr = output.data<char>();
+    size_t n_bytes = input.nbytes();
+    auto& encoder = cpu::get_command_encoder(stream);
+    encoder.set_input_array(input);
+    encoder.set_output_array(output);
+    encoder.dispatch([in_ptr, out_ptr, n_bytes, this]() {
+      // Copy our data to the appropriate place
+      std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes);
+
+      constexpr int64_t N = BUFFER_SIZE;
+      int64_t total = static_cast<int64_t>(n_bytes);
+      int64_t offset = N;
+      int a = 0, b = 1;
+
+      int mask_init = 1 << rank_;
+      int mask_target = (1 << size_) - 1;
+
+      post_recv_all(a);
+      std::copy(
+          in_ptr,
+          in_ptr + std::min(N, total),
+          cm_.send_buffer(a).begin<char>());
+      post_send_all(a);
+
+      int mask_a_send = mask_init;
+      int mask_a_recv = mask_init;
+      int mask_b_send = mask_init;
+      int mask_b_recv = mask_init;
+
+      while (offset < total) {
+        // While the previous chunk is in flight copy to the next send buffer
+        std::copy(
+            in_ptr + offset,
+            in_ptr + std::min(offset + N, total),
+            cm_.send_buffer(b).begin<char>());
+
+        // Send if the previous send is already done
+        for (int i = 0, m = 1; i < size_; i++, m *= 2) {
+          if (i == rank_) {
+            continue;
+          }
+          if (mask_a_send & m) {
+            cm_.send_to(i, b);
+          }
+        }
+
+        // Recv the next buffer if the previous one is already done
+        for (int i = 0, m = 1; i < size_; i++, m *= 2) {
+          if (i == rank_) {
+            continue;
+          }
+          if (mask_a_recv & m) {
+            cm_.recv_from(i, b);
+            std::copy(
+                cm_.buffer(i, a).begin<char>(),
+                cm_.buffer(i, a).begin<char>() + std::min(N, total - offset),
+                out_ptr + i * n_bytes + std::max(offset - N, 0LL));
+          }
+        }
+
+        // Loop until this chunk is all done
+        while (mask_a_send != mask_target || mask_a_recv != mask_target) {
+          ibv_wc wc[8];
+          int n = cm_.poll(8, wc);
+          for (int i = 0; i < n; i++) {
+            int work_type = wc[i].wr_id >> 16;
+            int buff = (wc[i].wr_id >> 8) & 0xff;
+            int rank = wc[i].wr_id & 0xff;
+
+            if (work_type == SEND_WR) {
+              if (buff == a) {
+                cm_.send_to(rank, b);
+                mask_a_send |= 1 << rank;
+              } else {
+                mask_b_send |= 1 << rank;
+              }
+            } else {
+              if (buff == a) {
+                cm_.recv_from(rank, b);
+                mask_a_recv |= 1 << rank;
+                std::copy(
+                    cm_.buffer(rank, a).begin<char>(),
+                    cm_.buffer(rank, a).begin<char>() +
+                        std::min(N, total - offset),
+                    out_ptr + rank * n_bytes + std::max(offset - N, 0LL));
+              } else {
+                mask_b_recv |= 1 << rank;
+              }
+            }
+          }
+        }
+
+        std::swap(a, b);
+        mask_a_send = mask_b_send;
+        mask_a_recv = mask_b_recv;
+        mask_b_send = mask_init;
+        mask_b_recv = mask_init;
+        offset += N;
+      }
+
+      {
+        while (mask_a_send != mask_target || mask_a_recv != mask_target) {
+          ibv_wc wc[8];
+          int n = cm_.poll(8, wc);
+          for (int i = 0; i < n; i++) {
+            int work_type = wc[i].wr_id >> 16;
+            int buff = (wc[i].wr_id >> 8) & 0xff;
+            int rank = wc[i].wr_id & 0xff;
+
+            if (work_type == SEND_WR) {
+              mask_a_send |= 1 << rank;
+            } else {
+              mask_a_recv |= 1 << rank;
+              std::copy(
+                  cm_.buffer(rank, a).begin<char>(),
+                  cm_.buffer(rank, a).begin<char>() +
+                      std::min(N, total + N - offset),
+                  out_ptr + rank * n_bytes + offset - N);
+            }
+          }
+        }
+      }
+    });
+  }
+
-  void send(const array& input, int dst, Stream stream) override {
-  }
+  void send(const array& input, int dst, Stream stream) override {}

-  void recv(array& out, int src, Stream stream) override {
-  }
+  void recv(array& out, int src, Stream stream) override {}

   std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
     throw std::runtime_error("[ibv] Group split not supported.");
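Note: the new all_gather pipelines chunked transfers through two staging buffers — while chunk a is in flight, chunk b is copied into the second send buffer — and per-peer bitmasks track progress: mask_init = 1 << rank_ pre-marks our own rank (we never message ourselves), mask_target = (1 << size_) - 1 is the all-peers-done state, and the moment a peer's completion for buffer a arrives, its work for buffer b is posted. The shape of that loop, stripped of the verbs machinery (a simplified sketch, not the commit's code; stage/send_chunk/wait_all are hypothetical stand-ins for the post/poll bookkeeping):

#include <algorithm>
#include <cstdint>
#include <utility>

// Hypothetical stand-ins for post_send_all / post_recv_all / cm_.poll();
// they only mark where the pipeline does staging, posting, and draining.
static void stage(const char* src, int64_t n, int buf) { (void)src; (void)n; (void)buf; }
static void send_chunk(int buf) { (void)buf; }
static void wait_all(int buf) { (void)buf; }  // poll until masks hit target

// Skeleton of the double-buffered chunk pipeline: while buffer `a` is on
// the wire, the next chunk is staged into buffer `b`.
void pipelined_gather(const char* in, int64_t total, int64_t N) {
  int a = 0, b = 1;
  int64_t offset = N;
  stage(in, std::min(N, total), a);
  send_chunk(a);
  while (offset < total) {
    stage(in + offset, std::min(N, total - offset), b);  // overlaps with a
    wait_all(a);   // peers finished with a; its buffer is reusable
    send_chunk(b);
    std::swap(a, b);
    offset += N;
  }
  wait_all(a);     // drain the final in-flight chunk
}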
@@ -746,6 +862,7 @@ class IBVGroup : public GroupImpl {
   auto in_ptr = input.data<T>();
   auto out_ptr = output.data<T>();
   auto& encoder = cpu::get_command_encoder(stream);
   encoder.set_input_array(input);
   encoder.set_output_array(output);
+  encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() {
   // If not inplace all reduce then copy the input to the output first
@@ -765,10 +882,7 @@ class IBVGroup : public GroupImpl {

   // Handle the first piece of data.
   post_recv_all(a);
-  std::copy(
-      data,
-      data + std::min(N, total),
-      cm_.send_buffer(a).begin<T>());
+  std::copy(data, data + std::min(N, total), cm_.send_buffer(a).begin<T>());
   post_send_all(a);

   int mask_a_send = mask_init;
@@ -905,8 +1019,8 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
   auto cm = ConnectionManager(rank, device_names, coordinator);
   if (cm.size() > MAX_PEERS) {
     std::ostringstream msg;
-    msg << "[ibv] The maximum number of supported peers is "
-        << MAX_PEERS << " but " << cm.size() << " was provided";
+    msg << "[ibv] The maximum number of supported peers is " << MAX_PEERS
+        << " but " << cm.size() << " was provided";
     throw std::runtime_error(msg.str());
   }

@@ -914,42 +1028,6 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
   cm.barrier();

   return std::make_shared<IBVGroup>(std::move(cm));
-
-  //cm.recv_from(rank ^ 1, 0);
-  //cm.barrier();
-  //std::fill(
-  //    cm.buffer(rank, 0).begin<int>(),
-  //    cm.buffer(rank, 0).end<int>(),
-  //    2 * rank + 1);
-  //cm.barrier();
-  //cm.send_to(rank ^ 1, 0);
-
-  //ibv_wc wc[8];
-  //bool sent = false, recvd = false;
-  //while (!(sent && recvd)) {
-  //  int num_completions = cm.poll(8, wc);
-  //  for (int i = 0; i < num_completions; i++) {
-  //    if (wc[i].status != IBV_WC_SUCCESS) {
-  //      std::cout << "Error " << wc[i].status << std::endl;
-  //    }
-
-  //    int work_type = wc[i].wr_id >> 16;
-  //    int buff = (wc[i].wr_id >> 8) & 0xff;
-  //    int rank = wc[i].wr_id & 0xff;
-
-  //    std::cout << work_type << " " << buff << " " << rank << std::endl;
-  //    print_wc(wc[i]);
-
-  //    sent |= (work_type == SEND_WR);
-  //    recvd |= (work_type == RECV_WR);
-  //  }
-  //}
-
-  //std::cout << rank << " ours " << *cm.buffer(rank, 0).data<int>() << std::endl;
-  //std::cout << rank << " theirs " << *cm.buffer(rank ^ 1, 0).data<int>()
-  //          << std::endl;
-
-  //return nullptr;
 }

 } // namespace mlx::core::distributed::ibv
mlx/distributed/reduction_ops.h (new file, 38 lines)
@@ -0,0 +1,38 @@
+// Copyright © 2025 Apple Inc.
+
+namespace mlx::core::distributed::detail {
+
+template <typename T>
+struct SumOp {
+  void operator()(const T* input, T* output, size_t N) const {
+    while (N-- > 0) {
+      *output += *input;
+      input++;
+      output++;
+    }
+  }
+};
+
+template <typename T>
+struct MaxOp {
+  void operator()(const T* input, T* output, size_t N) const {
+    while (N-- > 0) {
+      *output = std::max(*output, *input);
+      input++;
+      output++;
+    }
+  }
+};
+
+template <typename T>
+struct MinOp {
+  void operator()(const T* input, T* output, size_t N) const {
+    while (N-- > 0) {
+      *output = std::min(*output, *input);
+      input++;
+      output++;
+    }
+  }
+};
+
+} // namespace mlx::core::distributed::detail
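Note: these functors accumulate element-wise from input into output in place; as committed, the header relies on <algorithm> (for std::max/std::min) and size_t arriving transitively from its includers. A quick standalone usage sketch, with those includes added and the SumOp functor copied verbatim for illustration:

#include <algorithm>  // std::max / std::min, used by MaxOp / MinOp
#include <cassert>
#include <cstddef>

template <typename T>
struct SumOp {
  void operator()(const T* input, T* output, size_t N) const {
    while (N-- > 0) {
      *output += *input;
      input++;
      output++;
    }
  }
};

int main() {
  float in[3] = {1.f, 2.f, 3.f};
  float out[3] = {10.f, 20.f, 30.f};
  SumOp<float>{}(in, out, 3);  // accumulates element-wise into out
  assert(out[0] == 11.f && out[1] == 22.f && out[2] == 33.f);
  return 0;
}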
@@ -19,6 +19,7 @@
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
+#include "mlx/distributed/reduction_ops.h"
 #include "mlx/distributed/utils.h"
 #include "mlx/threadpool.h"

@@ -373,39 +374,6 @@ std::vector<int> make_connections(
   return sockets;
 }

-template <typename T>
-struct SumOp {
-  void operator()(const T* input, T* output, size_t N) {
-    while (N-- > 0) {
-      *output += *input;
-      input++;
-      output++;
-    }
-  }
-};
-
-template <typename T>
-struct MaxOp {
-  void operator()(const T* input, T* output, size_t N) {
-    while (N-- > 0) {
-      *output = std::max(*output, *input);
-      input++;
-      output++;
-    }
-  }
-};
-
-template <typename T>
-struct MinOp {
-  void operator()(const T* input, T* output, size_t N) {
-    while (N-- > 0) {
-      *output = std::min(*output, *input);
-      input++;
-      output++;
-    }
-  }
-};
-
 } // namespace

 class RingGroup : public GroupImpl {
@@ -506,17 +474,17 @@ class RingGroup : public GroupImpl {

   void all_sum(const array& input, array& output, Stream stream) override {
     SWITCH_TYPE(
-        output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()));
+        output, all_reduce<T>(input, output, stream, detail::SumOp<T>()));
   }

   void all_max(const array& input, array& output, Stream stream) override {
     SWITCH_TYPE(
-        output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()));
+        output, all_reduce<T>(input, output, stream, detail::MaxOp<T>()));
   }

   void all_min(const array& input, array& output, Stream stream) override {
     SWITCH_TYPE(
-        output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
+        output, all_reduce<T>(input, output, stream, detail::MinOp<T>()));
   }

   std::shared_ptr<GroupImpl> split(int color, int key = -1) override {