Initial working all reduce

This commit is contained in:
Angelos Katharopoulos
2025-09-09 13:32:06 -07:00
parent 27232db1ba
commit 67e454ab0a
9 changed files with 1273 additions and 163 deletions

View File

@@ -119,6 +119,10 @@ if(MLX_BUILD_METAL)
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-path"
OUTPUT_VARIABLE CMAKE_OSX_SYSROOT
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(

View File

@@ -4,6 +4,11 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
if(MLX_BUILD_CPU AND NOT WIN32)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ibv)

View File

@@ -5,6 +5,7 @@
#include "mlx/backend/cuda/cuda.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/ibv/ibv.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/nccl/nccl.h"
#include "mlx/distributed/ring/ring.h"
@@ -102,7 +103,8 @@ class EmptyGroup : public GroupImpl {
} // namespace detail
bool is_available() {
return mpi::is_available() || ring::is_available() || nccl::is_available();
return mpi::is_available() || ring::is_available() || nccl::is_available() ||
ibv::is_available();
}
int Group::rank() const {
@@ -135,6 +137,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = ring::init(strict);
} else if (bk == "nccl") {
group = nccl::init(strict);
} else if (bk == "ibv") {
group = ibv::init(strict);
} else if (bk == "any") {
if (mlx::core::cu::is_available()) {
group = nccl::init(false);
@@ -148,13 +152,17 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = mpi::init(false);
bk_ = "mpi";
}
if (group == nullptr) {
group = ibv::init(false);
bk_ = "ibv";
}
if (group == nullptr && strict) {
throw std::runtime_error("[distributed] Couldn't initialize any backend");
}
} else {
std::ostringstream msg;
msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
<< "and 'ring' but '" << bk << "' was provided.";
msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
<< "'ibv' and 'ring' but '" << bk << "' was provided.";
throw std::invalid_argument(msg.str());
}

View File

@@ -0,0 +1,2 @@
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ibv.cpp)
target_link_libraries(mlx PRIVATE rdma)

955
mlx/distributed/ibv/ibv.cpp Normal file
View File

@@ -0,0 +1,955 @@
// Copyright © 2025 Apple Inc.
#include <infiniband/verbs.h>
#include <fstream>
#include <iostream>
#include <unistd.h>
#include <json.hpp>
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/utils.h"
#include "mlx/dtype_utils.h"
constexpr const char* IBV_TAG = "[ibv]";
constexpr int NUM_BUFFERS = 2;
constexpr int BUFFER_SIZE = 4096;
constexpr int MAX_SEND_WR = 32;
constexpr int MAX_RECV_WR = 32;
constexpr int SEND_WR = 1;
constexpr int RECV_WR = 2;
constexpr int MAX_PEERS = 8;
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
using json = nlohmann::json;
namespace detail = mlx::core::distributed::detail;
namespace allocator = mlx::core::allocator;
template <typename T, typename = void>
struct is_container : std::false_type {};
template <typename T>
struct is_container<
T,
std::void_t<typename T::value_type, typename T::iterator>>
: std::true_type {};
std::ostream& operator<<(std::ostream& os, const ibv_gid& gid) {
os << std::hex << std::setfill('0');
for (int i = 0; i < 16; i += 2) {
uint16_t part = (gid.raw[i] << 8) | gid.raw[i + 1];
os << std::setw(4) << part;
if (i < 14)
os << ":";
}
os << std::dec;
return os;
}
void* page_aligned_alloc(size_t num_bytes) {
static size_t page_size = sysconf(_SC_PAGESIZE);
void * buf;
if (posix_memalign(&buf, page_size, num_bytes)) {
return nullptr;
}
return buf;
}
/**
* Contains the information that defines a destination to a remote device.
* Basically we can compute our own destination and share it with remote hosts
* over the side channel.
*/
struct Destination {
int local_id;
int queue_pair_number;
int packet_sequence_number;
ibv_gid global_identifier;
};
/**
* A buffer that can be registered to a number of protection domains.
*/
class SharedBuffer {
public:
SharedBuffer(size_t num_bytes)
: data_(page_aligned_alloc(num_bytes)),
num_bytes_(num_bytes) {}
~SharedBuffer() {
for (auto& [pd, mr] : memory_regions_) {
ibv_dereg_mr(mr);
}
if (data_ != nullptr) {
std::free(data_);
}
}
SharedBuffer(const SharedBuffer&) = delete;
SharedBuffer& operator=(const SharedBuffer&) = delete;
SharedBuffer(SharedBuffer&& b)
: data_(nullptr), num_bytes_(0) {
std::swap(data_, b.data_);
std::swap(num_bytes_, b.num_bytes_);
std::swap(memory_regions_, b.memory_regions_);
}
void register_to_protection_domain(ibv_pd* protection_domain) {
auto [it, inserted] = memory_regions_.insert({protection_domain, nullptr});
if (!inserted) {
throw std::runtime_error(
"[ibv] Buffer can be registered once per protection domain");
}
it->second = ibv_reg_mr(
protection_domain,
data_,
num_bytes_,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ |
IBV_ACCESS_REMOTE_WRITE);
if (!it->second) {
throw std::runtime_error("[ibv] Register memory region failed");
}
}
size_t size() const {
return num_bytes_;
}
uint32_t local_key(ibv_pd* protection_domain) const {
return memory_regions_.at(protection_domain)->lkey;
}
ibv_sge to_scatter_gather_entry(ibv_pd* protection_domain) const {
ibv_sge entry;
entry.addr = reinterpret_cast<uintptr_t>(data_);
entry.length = size();
entry.lkey = local_key(protection_domain);
return entry;
}
template <typename T>
T* data() {
return static_cast<T*>(data_);
}
template <typename T>
T* begin() {
return static_cast<T*>(data_);
}
template <typename T>
T* end() {
return static_cast<T*>(data_) + size() / sizeof(T);
}
private:
void* data_;
size_t num_bytes_;
std::unordered_map<ibv_pd*, ibv_mr*> memory_regions_;
};
/**
* Manipulates an RDMA connection. Enables (among other things)
*
* - Creating a queue pair
* - Sending and receiving
* - Checking completion
*/
struct Connection {
ibv_context* ctx;
ibv_pd* protection_domain;
ibv_cq* completion_queue;
ibv_qp* queue_pair;
Destination src; // holds the local information
Connection(ibv_context* ctx_)
: ctx(ctx_),
protection_domain(nullptr),
completion_queue(nullptr),
queue_pair(nullptr) {
src.local_id = -1;
}
Connection(Connection&& c) : Connection(nullptr) {
std::swap(ctx, c.ctx);
std::swap(protection_domain, c.protection_domain);
std::swap(completion_queue, c.completion_queue);
std::swap(queue_pair, c.queue_pair);
std::swap(src, c.src);
}
Connection(const Connection&) = delete;
Connection& operator=(Connection&) = delete;
~Connection() {
if (queue_pair != nullptr) {
ibv_destroy_qp(queue_pair);
}
if (completion_queue != nullptr) {
ibv_destroy_cq(completion_queue);
}
if (protection_domain != nullptr) {
ibv_dealloc_pd(protection_domain);
}
if (ctx != nullptr) {
ibv_close_device(ctx);
}
}
void allocate_protection_domain() {
protection_domain = ibv_alloc_pd(ctx);
if (protection_domain == nullptr) {
throw std::runtime_error("[ibv] Couldn't allocate protection domain");
}
}
void create_completion_queue(int num_entries) {
completion_queue = ibv_create_cq(ctx, num_entries, nullptr, nullptr, 0);
if (completion_queue == nullptr) {
throw std::runtime_error("[ibv] Couldn't create completion queue");
}
}
void create_queue_pair() {
ibv_qp_init_attr init_attr;
init_attr.qp_context = ctx;
init_attr.qp_context = ctx;
init_attr.send_cq = completion_queue;
init_attr.recv_cq = completion_queue;
init_attr.srq = nullptr;
init_attr.cap.max_send_wr = MAX_SEND_WR;
init_attr.cap.max_recv_wr = MAX_RECV_WR;
init_attr.cap.max_send_sge = 1;
init_attr.cap.max_recv_sge = 1;
init_attr.cap.max_inline_data = 0;
init_attr.qp_type = IBV_QPT_UC;
init_attr.sq_sig_all = 0;
queue_pair = ibv_create_qp(protection_domain, &init_attr);
if (queue_pair == nullptr) {
throw std::runtime_error("[ibv] Couldn't create queue pair");
}
}
const Destination& info() {
if (queue_pair == nullptr || src.local_id >= 0) {
return src;
}
ibv_port_attr port_attr;
ibv_query_port(ctx, 1, &port_attr);
ibv_gid gid;
ibv_query_gid(ctx, 1, 1, &gid);
src.local_id = port_attr.lid;
src.queue_pair_number = queue_pair->qp_num;
src.packet_sequence_number = 7; // TODO: Change to sth random
src.global_identifier = gid;
return src;
}
void queue_pair_init() {
ibv_qp_attr attr = {};
attr.qp_state = IBV_QPS_INIT;
attr.port_num = 1;
attr.pkey_index = 0;
attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ |
IBV_ACCESS_REMOTE_WRITE;
int mask =
IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS;
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
std::ostringstream msg;
msg << "[ibv] Changing queue pair to INIT failed with errno " << status;
throw std::invalid_argument(msg.str());
}
}
void queue_pair_rtr(const Destination& dst) {
ibv_qp_attr attr = {};
memset(&attr, 0, sizeof(attr));
attr.qp_state = IBV_QPS_RTR;
attr.path_mtu = IBV_MTU_1024;
attr.rq_psn = dst.packet_sequence_number;
attr.dest_qp_num = dst.queue_pair_number;
attr.ah_attr.dlid = dst.local_id;
attr.ah_attr.sl = 0;
attr.ah_attr.src_path_bits = 0;
attr.ah_attr.port_num = 1;
attr.ah_attr.is_global = 0;
if (dst.global_identifier.global.interface_id) {
attr.ah_attr.is_global = 1;
attr.ah_attr.grh.hop_limit = 1;
attr.ah_attr.grh.dgid = dst.global_identifier;
attr.ah_attr.grh.sgid_index = 1;
}
int mask = IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN |
IBV_QP_RQ_PSN;
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
std::ostringstream msg;
msg << "[ibv] Changing queue pair to RTR failed with errno " << status;
throw std::invalid_argument(msg.str());
}
}
void queue_pair_rts() {
ibv_qp_attr attr = {};
attr.qp_state = IBV_QPS_RTS;
attr.sq_psn = src.packet_sequence_number;
int mask = IBV_QP_STATE | IBV_QP_SQ_PSN;
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
std::ostringstream msg;
msg << "[ibv] Changing queue pair to RTS failed with errno " << status;
throw std::invalid_argument(msg.str());
}
}
void post_send(const SharedBuffer& buff, uint64_t work_request_id) {
ibv_send_wr work_request, *bad_work_request;
auto entry = buff.to_scatter_gather_entry(protection_domain);
work_request.wr_id = work_request_id;
work_request.sg_list = &entry;
work_request.num_sge = 1;
work_request.opcode = IBV_WR_SEND;
work_request.send_flags = IBV_SEND_SIGNALED;
work_request.next = nullptr;
if (int status =
ibv_post_send(queue_pair, &work_request, &bad_work_request);
status != 0) {
std::ostringstream msg;
msg << "[ibv] Send failed with error code " << status;
throw std::invalid_argument(msg.str());
}
}
void post_recv(const SharedBuffer& buff, uint64_t work_request_id) {
ibv_recv_wr work_request, *bad_work_request;
auto entry = buff.to_scatter_gather_entry(protection_domain);
work_request.wr_id = work_request_id;
work_request.sg_list = &entry;
work_request.num_sge = 1;
work_request.next = nullptr;
if (int status =
ibv_post_recv(queue_pair, &work_request, &bad_work_request);
status != 0) {
std::ostringstream msg;
msg << "[ibv] Recv failed with error code " << status;
throw std::invalid_argument(msg.str());
}
}
};
/**
* Implement a TCP side channel to exchange information about the RDMA
* connections.
*
* Implements a simple all gather where every node sends to rank 0 and rank 0
* broadcasts to every node.
*/
class SideChannel {
public:
SideChannel(int rank, int size, const char* addr) : rank_(rank), size_(size) {
auto address = detail::parse_address(addr);
if (rank_ == 0) {
detail::TCPSocket server(IBV_TAG);
server.listen(IBV_TAG, address);
for (int i = 0; i < size - 1; i++) {
sockets_.push_back(server.accept(IBV_TAG));
}
} else {
sockets_.push_back(detail::TCPSocket::connect(
IBV_TAG, address, 4, 1000, [](int attempt, int wait) {
std::cerr << IBV_TAG << " Connection attempt " << attempt
<< " waiting " << wait << " ms" << std::endl;
}));
}
}
SideChannel(const SideChannel&) = delete;
SideChannel& operator=(const SideChannel&) = delete;
SideChannel(SideChannel&& sc)
: rank_(sc.rank_), size_(sc.size_), sockets_(std::move(sc.sockets_)) {
sc.rank_ = -1;
sc.size_ = -1;
}
template <typename T>
std::vector<T> all_gather(const T& v) {
std::vector<T> result(size_);
// T is a container of stuff like std::vector or std::string
if constexpr (is_container<T>::value) {
using U = typename T::value_type;
// Share the lengths first and set the communication size to be the
// maximum length of the containers.
auto lengths = all_gather<int>(v.size());
auto max_len = *std::max_element(lengths.begin(), lengths.end());
for (auto& s : result) {
s.resize(max_len);
}
// All gather of length max_len
if (rank_ == 0) {
std::copy(v.begin(), v.end(), result[rank_].begin());
for (int i = 1; i < size_; i++) {
sockets_[i - 1].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len);
}
for (int i = 1; i < size_; i++) {
for (int j = 0; j < size_; j++) {
sockets_[i - 1].send(
IBV_TAG, result[j].data(), sizeof(U) * max_len);
}
}
} else {
std::copy(v.begin(), v.end(), result[rank_].begin());
sockets_[0].send(IBV_TAG, result[rank_].data(), sizeof(U) * max_len);
for (int i = 0; i < size_; i++) {
sockets_[0].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len);
}
}
// Resize the outputs back to the original length
for (int i = 0; i < size_; i++) {
result[i].resize(lengths[i]);
}
}
// T is a scalar
else {
if (rank_ == 0) {
result[rank_] = v;
for (int i = 1; i < size_; i++) {
sockets_[i - 1].recv(IBV_TAG, &result[i], sizeof(T));
}
for (int i = 1; i < size_; i++) {
sockets_[i - 1].send(IBV_TAG, result.data(), size_ * sizeof(T));
}
} else {
sockets_[0].send(IBV_TAG, &v, sizeof(T));
sockets_[0].recv(IBV_TAG, result.data(), size_ * sizeof(T));
}
}
return result;
}
private:
int rank_;
int size_;
std::vector<detail::TCPSocket> sockets_;
};
/**
* Manages a set of connections. Among other things it uses a side channel to
* exchange the necessary information and then configure the connections to be
* ready for RDMA operations.
*/
class ConnectionManager {
public:
ConnectionManager(
int rank,
const std::vector<std::string>& device_names,
const char* coordinator_addr)
: rank_(rank),
size_(device_names.size()),
side_channel_(rank_, size_, coordinator_addr) {
create_contexts(device_names);
if (connections_[rank_].ctx != nullptr) {
throw std::runtime_error("[ibv] Malformed device file");
}
}
int rank() const {
return rank_;
}
int size() const {
return size_;
}
/**
* Performs the connection initialization. Namely, after this call all
* Connection objects should have a queue pair in RTS state.
*/
void initialize(int num_buffers, size_t num_bytes) {
// Create the queue pairs
for (auto& conn : connections_) {
if (conn.ctx == nullptr) {
continue;
}
conn.allocate_protection_domain();
conn.create_completion_queue(MAX_SEND_WR + MAX_RECV_WR);
conn.create_queue_pair();
}
allocate_buffers(num_buffers, num_bytes);
// Gather the information to be exchanged
std::vector<Destination> info;
for (auto& conn : connections_) {
info.emplace_back(conn.info());
}
auto all_infos = side_channel_.all_gather(info);
// Transition queue pairs to RTS
for (int peer = 0; peer < size_; peer++) {
if (peer == rank_) {
continue;
}
auto peer_info = all_infos[peer][rank_];
connections_[peer].queue_pair_init();
connections_[peer].queue_pair_rtr(peer_info);
connections_[peer].queue_pair_rts();
}
}
void allocate_buffers(int num_buffers, size_t num_bytes) {
// Deregister any buffers and free the memory
buffers_.clear();
// Allocate the memory
for (int i = 0; i < num_buffers; i++) {
for (int j = 0; j < size_; j++) {
buffers_.emplace_back(num_bytes);
}
}
for (int i = 0; i < num_buffers; i++) {
for (int j = 0; j < size_; j++) {
// This is our send buffer so register it with all pds so we can send
// it to all connected devices.
if (j == rank_) {
for (auto& conn : connections_) {
if (conn.ctx != nullptr) {
buffers_[i * size_ + j].register_to_protection_domain(
conn.protection_domain);
}
}
}
// This is the recv buffer from rank j so register it to rank j's
// protection domain.
else {
buffers_[i * size_ + j].register_to_protection_domain(
connections_[j].protection_domain);
}
}
}
}
void send_to(int rank, int buff) {
connections_[rank].post_send(
buffers_[buff * size_ + rank_],
SEND_WR << 16 | buff << 8 | rank);
}
void recv_from(int rank, int buff) {
connections_[rank].post_recv(
buffers_[buff * size_ + rank], RECV_WR << 16 | buff << 8 | rank);
}
/**
* Poll all connections and save the work completions and return the
* corresponding length.
*/
int poll(int num_completions, ibv_wc* work_completions) {
int completions = 0;
for (int r = 0; r < size_; r++) {
if (r == rank_) {
continue;
}
if (completions >= num_completions) {
return completions;
}
int c = ibv_poll_cq(
connections_[r].completion_queue,
num_completions - completions,
work_completions + completions);
completions += c;
}
return completions;
}
/**
*
*/
int poll(int rank, int num_completions, ibv_wc* work_completions) {
return ibv_poll_cq(
connections_[rank].completion_queue,
num_completions,
work_completions);
}
SharedBuffer& send_buffer(int buff) {
return buffers_[buff * size_ + rank_];
}
SharedBuffer& buffer(int rank, int buff) {
return buffers_[buff * size_ + rank];
}
void barrier() {
side_channel_.all_gather<int>(0);
}
private:
void create_contexts(const std::vector<std::string>& device_names) {
int num_devices = 0;
ibv_device** devices = ibv_get_device_list(&num_devices);
for (auto& name : device_names) {
// Empty so add a nullptr context
if (name.empty()) {
connections_.emplace_back(nullptr);
continue;
}
// Search for the name and try to open the device
for (int i = 0; i < num_devices; i++) {
if (name == ibv_get_device_name(devices[i])) {
auto ctx = ibv_open_device(devices[i]);
if (ctx == nullptr) {
std::ostringstream msg;
msg << "[ibv] Could not open device " << name;
throw std::runtime_error(msg.str());
}
connections_.emplace_back(ctx);
break;
}
}
}
ibv_free_device_list(devices);
}
int rank_;
int size_;
SideChannel side_channel_;
std::vector<Connection> connections_;
std::vector<SharedBuffer> buffers_;
};
template <typename T>
struct SumOp {
void operator()(const T* input, T* output, size_t N) const {
while (N-- > 0) {
*output += *input;
input++;
output++;
}
}
};
std::vector<std::string> load_device_names(int rank, const char* dev_file) {
std::vector<std::string> device_names;
std::ifstream f(dev_file);
json devices = json::parse(f);
devices = devices[rank];
for (auto it = devices.begin(); it != devices.end(); it++) {
std::string n;
if (!it->is_null()) {
n = *it;
}
device_names.emplace_back(std::move(n));
}
return device_names;
}
namespace mlx::core::distributed::ibv {
class IBVGroup : public GroupImpl {
public:
IBVGroup(ConnectionManager cm) : cm_(std::move(cm)), rank_(cm.rank()), size_(cm.size()) {}
Stream communication_stream(StreamOrDevice s) override {
return to_stream(s, Device::cpu);
}
int rank() override {
return cm_.rank();
}
int size() override {
return cm_.size();
}
void all_sum(const array& input, array& output, Stream stream) override {
dispatch_all_types(output.dtype(), [&](auto type_tag) {
using T = MLX_GET_TYPE(type_tag);
all_reduce<T>(input, output, stream, SumOp<T>{});
});
}
void all_max(const array& input, array& output, Stream stream) override {
}
void all_min(const array& input, array& output, Stream stream) override {
}
void all_gather(const array& input, array& output, Stream stream) override {
}
void send(const array& input, int dst, Stream stream) override {
}
void recv(array& out, int src, Stream stream) override {
}
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
throw std::runtime_error("[ibv] Group split not supported.");
}
private:
void post_recv_all(int buffer) {
for (int i = 0; i < size_; i++) {
if (i == rank_) {
continue;
}
cm_.recv_from(i, buffer);
}
}
void post_send_all(int buffer) {
for (int i = 0; i < size_; i++) {
if (i == rank_) {
continue;
}
cm_.send_to(i, buffer);
}
}
template <typename T, typename ReduceOp>
void all_reduce(
const array& input,
array& output,
Stream stream,
ReduceOp reduce_op) {
auto in_ptr = input.data<T>();
auto out_ptr = output.data<T>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(output);
encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() {
// If not inplace all reduce then copy the input to the output first
if (in_ptr != out_ptr) {
std::memcpy(out_ptr, in_ptr, size * sizeof(T));
}
// Fully connected all reduce
T* data = out_ptr;
constexpr int64_t N = BUFFER_SIZE / sizeof(T);
int64_t total = static_cast<int64_t>(size);
int64_t offset = N;
int a = 0, b = 1;
int mask_init = 1 << rank_;
int mask_target = (1 << size_) - 1;
// Handle the first piece of data.
post_recv_all(a);
std::copy(
data,
data + std::min(N, total),
cm_.send_buffer(a).begin<T>());
post_send_all(a);
int mask_a_send = mask_init;
int mask_a_recv = mask_init;
int mask_b_send = mask_init;
int mask_b_recv = mask_init;
while (offset < total) {
// While the previous chunk is in flight copy to the next send buffer
std::copy(
data + offset,
data + std::min(offset + N, total),
cm_.send_buffer(b).begin<T>());
// Send if the previous send is already done
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
if (i == rank_) {
continue;
}
if (mask_a_send & m) {
cm_.send_to(i, b);
}
}
// Recv the next buffer if the previous one is already done
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
if (i == rank_) {
continue;
}
if (mask_a_recv & m) {
cm_.recv_from(i, b);
reduce_op(
cm_.buffer(i, a).begin<T>(),
data + std::max(offset - N, 0LL),
std::min(N, total - offset));
}
}
// Loop until this chunk is all done
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
ibv_wc wc[8];
int n = cm_.poll(8, wc);
for (int i = 0; i < n; i++) {
int work_type = wc[i].wr_id >> 16;
int buff = (wc[i].wr_id >> 8) & 0xff;
int rank = wc[i].wr_id & 0xff;
if (work_type == SEND_WR) {
if (buff == a) {
cm_.send_to(rank, b);
mask_a_send |= 1 << rank;
} else {
mask_b_send |= 1 << rank;
}
} else {
if (buff == a) {
cm_.recv_from(rank, b);
mask_a_recv |= 1 << rank;
reduce_op(
cm_.buffer(rank, a).begin<T>(),
data + std::max(offset - N, 0LL),
std::min(N, total - offset));
} else {
mask_b_recv |= 1 << rank;
}
}
}
}
std::swap(a, b);
mask_a_send = mask_b_send;
mask_a_recv = mask_b_recv;
mask_b_send = mask_init;
mask_b_recv = mask_init;
offset += N;
}
{
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
ibv_wc wc[8];
int n = cm_.poll(8, wc);
for (int i = 0; i < n; i++) {
int work_type = wc[i].wr_id >> 16;
int buff = (wc[i].wr_id >> 8) & 0xff;
int rank = wc[i].wr_id & 0xff;
if (work_type == SEND_WR) {
mask_a_send |= 1 << rank;
} else {
mask_a_recv |= 1 << rank;
reduce_op(
cm_.buffer(rank, a).begin<T>(),
data + offset - N,
std::min(N, total + N - offset));
}
}
}
}
});
}
ConnectionManager cm_;
int rank_;
int size_;
};
bool is_available() {
return true;
}
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
const char* dev_file = std::getenv("MLX_IBV_DEVICES");
const char* coordinator = std::getenv("MLX_IBV_COORDINATOR");
const char* rank_str = std::getenv("MLX_RANK");
const char* ring_verbose = std::getenv("MLX_IBV_VERBOSE");
if (!dev_file || !coordinator || !rank_str) {
if (strict) {
std::ostringstream msg;
msg << "[ibv] You need to provide via environment variables a rank (MLX_RANK), "
<< "a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_IBV_COORDINATOR) "
<< "but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "")
<< "\", MLX_IBV_DEVICES=\"" << ((dev_file) ? dev_file : "")
<< "\" and MLX_IBV_COORDINATOR=\""
<< ((coordinator) ? coordinator : "");
throw std::runtime_error(msg.str());
}
return nullptr;
}
auto rank = std::atoi(rank_str);
auto device_names = load_device_names(rank, dev_file);
auto cm = ConnectionManager(rank, device_names, coordinator);
if (cm.size() > MAX_PEERS) {
std::ostringstream msg;
msg << "[ibv] The maximum number of supported peers is "
<< MAX_PEERS << " but " << cm.size() << " was provided";
throw std::runtime_error(msg.str());
}
cm.initialize(NUM_BUFFERS, BUFFER_SIZE);
cm.barrier();
return std::make_shared<IBVGroup>(std::move(cm));
//cm.recv_from(rank ^ 1, 0);
//cm.barrier();
//std::fill(
// cm.buffer(rank, 0).begin<int>(),
// cm.buffer(rank, 0).end<int>(),
// 2 * rank + 1);
//cm.barrier();
//cm.send_to(rank ^ 1, 0);
//ibv_wc wc[8];
//bool sent = false, recvd = false;
//while (!(sent && recvd)) {
// int num_completions = cm.poll(8, wc);
// for (int i = 0; i < num_completions; i++) {
// if (wc[i].status != IBV_WC_SUCCESS) {
// std::cout << "Error " << wc[i].status << std::endl;
// }
// int work_type = wc[i].wr_id >> 16;
// int buff = (wc[i].wr_id >> 8) & 0xff;
// int rank = wc[i].wr_id & 0xff;
// std::cout << work_type << " " << buff << " " << rank << std::endl;
// print_wc(wc[i]);
// sent |= (work_type == SEND_WR);
// recvd |= (work_type == RECV_WR);
// }
//}
//std::cout << rank << " ours " << *cm.buffer(rank, 0).data<int>() << std::endl;
//std::cout << rank << " theirs " << *cm.buffer(rank ^ 1, 0).data<int>()
// << std::endl;
//return nullptr;
}
} // namespace mlx::core::distributed::ibv

12
mlx/distributed/ibv/ibv.h Normal file
View File

@@ -0,0 +1,12 @@
// Copyright © 2025 Apple Inc.
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::ibv {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);
} // namespace mlx::core::distributed::ibv

View File

@@ -1,9 +1,6 @@
// Copyright © 2024 Apple Inc.
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <unistd.h>
@@ -22,6 +19,7 @@
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/utils.h"
#include "mlx/threadpool.h"
#ifndef SOL_TCP
@@ -94,6 +92,7 @@ constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
constexpr const size_t ALL_SUM_BUFFERS = 2;
constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;
constexpr const char* RING_TAG = "[ring]";
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
using json = nlohmann::json;
@@ -296,55 +295,6 @@ class CommunicationThreads {
std::unordered_map<int, SocketThread> threads_;
};
struct address_t {
sockaddr_storage addr;
socklen_t len;
const sockaddr* get() const {
return (struct sockaddr*)&addr;
}
};
/**
* Parse a sockaddr from an ip and port provided as strings.
*/
address_t parse_address(const std::string& ip, const std::string& port) {
struct addrinfo hints, *res;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
if (status != 0) {
std::ostringstream msg;
msg << "Can't parse address " << ip << ":" << port;
throw std::runtime_error(msg.str());
}
address_t result;
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
result.len = res->ai_addrlen;
freeaddrinfo(res);
return result;
}
/**
* Parse a sockaddr provided as an <ip>:<port> string.
*/
address_t parse_address(const std::string& ip_port) {
auto colon = ip_port.find(":");
if (colon == std::string::npos) {
std::ostringstream msg;
msg << "Can't parse address " << ip_port;
throw std::runtime_error(msg.str());
}
std::string ip(ip_port.begin(), ip_port.begin() + colon);
std::string port(ip_port.begin() + colon + 1, ip_port.end());
return parse_address(ip, port);
}
/**
* Load all addresses from the json hostfile. The hostfile is a list of
* addresses in order of rank. For each rank there can be many addresses so
@@ -357,15 +307,15 @@ address_t parse_address(const std::string& ip_port) {
* ["ip3:5000", "ip3:5001"],
* ]
*/
std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
std::vector<std::vector<address_t>> nodes;
std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
std::vector<std::vector<detail::address_t>> nodes;
std::ifstream f(hostfile);
json hosts = json::parse(f);
for (auto& h : hosts) {
std::vector<address_t> host;
std::vector<detail::address_t> host;
for (auto& ips : h) {
host.push_back(parse_address(ips.get<std::string>()));
host.push_back(std::move(detail::parse_address(ips.get<std::string>())));
}
nodes.push_back(std::move(host));
}
@@ -377,73 +327,15 @@ std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
* Create a socket and accept one connection for each of the provided
* addresses.
*/
std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
std::vector<int> accept_connections(
const std::vector<detail::address_t>& addresses) {
std::vector<int> sockets;
int success;
for (auto& address : addresses) {
// Create the socket to wait for connections from the peers
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Make sure we can launch immediately after shutdown by setting the
// reuseaddr option so that we don't get address already in use errors
int enable = 1;
success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't enable reuseport (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Bind the socket to the address and port
success = bind(sock, address.get(), address.len);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't bind socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Wait for connections
success = listen(sock, 0);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't listen (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
int peer_socket = accept(sock, nullptr, nullptr);
if (peer_socket < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Accept failed (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Close the listening socket
shutdown(sock, 2);
close(sock);
sockets.push_back(peer_socket);
detail::TCPSocket socket(RING_TAG);
socket.listen(RING_TAG, address);
sockets.push_back(socket.accept(RING_TAG));
}
return sockets;
@@ -454,55 +346,33 @@ std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
* provided addresses.
*/
std::vector<int> make_connections(
const std::vector<address_t>& addresses,
const std::vector<detail::address_t>& addresses,
bool verbose) {
std::vector<int> sockets;
int success;
for (auto& address : addresses) {
int sock;
// Attempt to connect to the peer CONN_ATTEMPTS times with exponential
// backoff. TODO: Do we need that?
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
// Create the socket
sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
if (attempt > 0) {
int wait = (1 << (attempt - 1)) * CONN_WAIT;
log_info(
verbose,
"Attempt",
attempt,
"wait",
wait,
"ms (error:",
errno,
")");
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
}
success = connect(sock, address.get(), address.len);
if (success == 0) {
break;
}
}
if (success < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't connect (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
sockets.push_back(sock);
sockets.push_back(detail::TCPSocket::connect(
RING_TAG,
address,
CONN_ATTEMPTS,
CONN_WAIT,
[verbose](int attempt, int wait) {
log_info(
verbose,
"Attempt",
attempt,
"waiting",
wait,
"ms (error:",
errno,
")");
}));
}
return sockets;
}
template <typename T>
struct SumOp {
void operator()(const T* input, T* output, size_t N) {
@@ -540,7 +410,10 @@ struct MinOp {
class RingGroup : public GroupImpl {
public:
RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
RingGroup(
int rank,
std::vector<std::vector<detail::address_t>> nodes,
bool verbose)
: rank_(rank), verbose_(verbose), pool_(0) {
if (rank_ > 0 && rank_ >= nodes.size()) {
throw std::runtime_error(

189
mlx/distributed/utils.cpp Normal file
View File

@@ -0,0 +1,189 @@
// Copyright © 2025 Apple Inc.
#include <netdb.h>
#include <unistd.h>
#include <sstream>
#include <thread>
#include "mlx/distributed/utils.h"
namespace mlx::core::distributed::detail {
/**
* Parse a sockaddr from an ip and port provided as strings.
*/
address_t parse_address(const std::string& ip, const std::string& port) {
struct addrinfo hints, *res;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
if (status != 0) {
std::ostringstream msg;
msg << "Can't parse address " << ip << ":" << port;
throw std::runtime_error(msg.str());
}
address_t result;
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
result.len = res->ai_addrlen;
freeaddrinfo(res);
return result;
}
/**
* Parse a sockaddr provided as an <ip>:<port> string.
*/
address_t parse_address(const std::string& ip_port) {
auto colon = ip_port.find(":");
if (colon == std::string::npos) {
std::ostringstream msg;
msg << "Can't parse address " << ip_port;
throw std::runtime_error(msg.str());
}
std::string ip(ip_port.begin(), ip_port.begin() + colon);
std::string port(ip_port.begin() + colon + 1, ip_port.end());
return parse_address(ip, port);
}
TCPSocket::TCPSocket(const char* tag) {
sock_ = socket(AF_INET, SOCK_STREAM, 0);
if (sock_ < 0) {
std::ostringstream msg;
msg << tag << " Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
}
TCPSocket::TCPSocket(TCPSocket&& s) {
sock_ = s.sock_;
s.sock_ = -1;
}
TCPSocket::TCPSocket(int s) : sock_(s) {}
TCPSocket::~TCPSocket() {
if (sock_ > 0) {
shutdown(sock_, 2);
close(sock_);
}
}
void TCPSocket::listen(const char* tag, const address_t& addr) {
int success;
// Make sure we can launch immediately after shutdown by setting the
// reuseaddr option so that we don't get address already in use errors
int enable = 1;
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't enable reuseaddr (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't enable reuseport (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Bind the socket to the address and port
success = bind(sock_, addr.get(), addr.len);
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't bind socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Prepare waiting for connections
success = ::listen(sock_, 0);
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't listen (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
}
TCPSocket TCPSocket::accept(const char* tag) {
int peer = ::accept(sock_, nullptr, nullptr);
if (peer < 0) {
std::ostringstream msg;
msg << tag << " Accept failed (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
return TCPSocket(peer);
}
void TCPSocket::send(const char* tag, const void* data, size_t len) {
while (len > 0) {
auto n = ::send(sock_, data, len, 0);
if (n <= 0) {
std::ostringstream msg;
msg << tag << " Send failed with errno=" << errno;
throw std::runtime_error(msg.str());
}
len -= n;
data = static_cast<const char*>(data) + n;
}
}
void TCPSocket::recv(const char* tag, void* data, size_t len) {
while (len > 0) {
auto n = ::recv(sock_, data, len, 0);
if (n <= 0) {
std::ostringstream msg;
msg << tag << " Recv failed with errno=" << errno;
throw std::runtime_error(msg.str());
}
len -= n;
data = static_cast<char*>(data) + n;
}
}
TCPSocket TCPSocket::connect(
const char* tag,
const address_t& addr,
int num_retries,
int wait,
std::function<void(int, int)> cb) {
int sock, success;
// Attempt to connect `num_retries` times with exponential backoff.
for (int attempt = 0; attempt < num_retries; attempt++) {
// Create the socket
sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << tag << " Couldn't create socket to connect (error: " << errno
<< ")";
throw std::runtime_error(msg.str());
}
success = ::connect(sock, addr.get(), addr.len);
if (success == 0) {
break;
}
cb(attempt, wait);
if (wait > 0) {
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
}
wait <<= 1;
}
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't connect (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
return TCPSocket(sock);
}
} // namespace mlx::core::distributed::detail

62
mlx/distributed/utils.h Normal file
View File

@@ -0,0 +1,62 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <sys/socket.h>
namespace mlx::core::distributed::detail {
struct address_t {
sockaddr_storage addr;
socklen_t len;
const sockaddr* get() const {
return (struct sockaddr*)&addr;
}
};
/**
* Parse a sockaddr from an ip and port provided as strings.
*/
address_t parse_address(const std::string& ip, const std::string& port);
/**
* Parse a sockaddr provided as an <ip>:<port> string.
*/
address_t parse_address(const std::string& ip_port);
/**
* Small wrapper over a TCP socket to simplify initiating connections.
*/
class TCPSocket {
public:
TCPSocket(const char* tag);
TCPSocket(const TCPSocket&) = delete;
TCPSocket& operator=(const TCPSocket&) = delete;
TCPSocket(TCPSocket&& s);
~TCPSocket();
void listen(const char* tag, const address_t& addr);
TCPSocket accept(const char* tag);
void send(const char* tag, const void* data, size_t len);
void recv(const char* tag, void* data, size_t len);
operator int() const {
return sock_;
}
static TCPSocket connect(
const char* tag,
const address_t& addr,
int num_retries = 1,
int wait = 0,
std::function<void(int, int)> cb = nullptr);
private:
TCPSocket(int sock);
int sock_;
};
} // namespace mlx::core::distributed::detail