// Copyright © 2025 Apple Inc.

#include <infiniband/verbs.h>
#include <unistd.h>

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <unordered_map>
#include <vector>

#include <json.hpp>

#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/utils.h"
#include "mlx/dtype_utils.h"

constexpr const char* IBV_TAG = "[ibv]";
constexpr int NUM_BUFFERS = 2;
constexpr int BUFFER_SIZE = 4096;
constexpr int MAX_SEND_WR = 32;
constexpr int MAX_RECV_WR = 32;
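// Work-request type tags packed into the high bits of wr_id (see send_to and
// recv_from below).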
constexpr int SEND_WR = 1;
constexpr int RECV_WR = 2;
constexpr int MAX_PEERS = 8;

using GroupImpl = mlx::core::distributed::detail::GroupImpl;
using json = nlohmann::json;
namespace detail = mlx::core::distributed::detail;
namespace allocator = mlx::core::allocator;
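
// Detects container-like types (anything exposing value_type and iterator);
// SideChannel::all_gather uses it to pick the container or scalar path.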
template <typename T, typename = void>
struct is_container : std::false_type {};

template <typename T>
struct is_container<
    T,
    std::void_t<typename T::value_type, typename T::iterator>>
    : std::true_type {};
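
// Print a GID as eight colon-separated 16-bit hex groups (IPv6-style).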
std::ostream& operator<<(std::ostream& os, const ibv_gid& gid) {
  os << std::hex << std::setfill('0');
  for (int i = 0; i < 16; i += 2) {
    uint16_t part = (gid.raw[i] << 8) | gid.raw[i + 1];
    os << std::setw(4) << part;
    if (i < 14)
      os << ":";
  }
  os << std::dec;
  return os;
}

void* page_aligned_alloc(size_t num_bytes) {
  static size_t page_size = sysconf(_SC_PAGESIZE);
  void* buf;
  if (posix_memalign(&buf, page_size, num_bytes)) {
    return nullptr;
  }
  return buf;
}

/**
 * Contains the information that identifies a queue pair as a destination for
 * RDMA operations. Each node computes its own destination info and shares it
 * with remote hosts over the side channel.
 */
struct Destination {
  int local_id;
  int queue_pair_number;
  int packet_sequence_number;
  ibv_gid global_identifier;
};

/**
 * A buffer that can be registered to a number of protection domains.
 */
class SharedBuffer {
 public:
  SharedBuffer(size_t num_bytes)
      : data_(page_aligned_alloc(num_bytes)), num_bytes_(num_bytes) {}

  ~SharedBuffer() {
    for (auto& [pd, mr] : memory_regions_) {
      ibv_dereg_mr(mr);
    }
    if (data_ != nullptr) {
      std::free(data_);
    }
  }

  SharedBuffer(const SharedBuffer&) = delete;
  SharedBuffer& operator=(const SharedBuffer&) = delete;
  SharedBuffer(SharedBuffer&& b) : data_(nullptr), num_bytes_(0) {
    std::swap(data_, b.data_);
    std::swap(num_bytes_, b.num_bytes_);
    std::swap(memory_regions_, b.memory_regions_);
  }

  void register_to_protection_domain(ibv_pd* protection_domain) {
    auto [it, inserted] = memory_regions_.insert({protection_domain, nullptr});
    if (!inserted) {
      throw std::runtime_error(
          "[ibv] Buffer can be registered once per protection domain");
    }

    it->second = ibv_reg_mr(
        protection_domain,
        data_,
        num_bytes_,
        IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ |
            IBV_ACCESS_REMOTE_WRITE);
    if (!it->second) {
      throw std::runtime_error("[ibv] Register memory region failed");
    }
  }

  size_t size() const {
    return num_bytes_;
  }

  uint32_t local_key(ibv_pd* protection_domain) const {
    return memory_regions_.at(protection_domain)->lkey;
  }

  ibv_sge to_scatter_gather_entry(ibv_pd* protection_domain) const {
    ibv_sge entry;
    entry.addr = reinterpret_cast<uintptr_t>(data_);
    entry.length = size();
    entry.lkey = local_key(protection_domain);
    return entry;
  }

  template <typename T>
  T* data() {
    return static_cast<T*>(data_);
  }

  template <typename T>
  T* begin() {
    return static_cast<T*>(data_);
  }

  template <typename T>
  T* end() {
    return static_cast<T*>(data_) + size() / sizeof(T);
  }

 private:
  void* data_;
  size_t num_bytes_;
  std::unordered_map<ibv_pd*, ibv_mr*> memory_regions_;
};

/**
 * Manipulates an RDMA connection. Enables (among other things)
 *
 * - Creating a queue pair
 * - Sending and receiving
 * - Checking completion
 */
struct Connection {
  ibv_context* ctx;
  ibv_pd* protection_domain;
  ibv_cq* completion_queue;
  ibv_qp* queue_pair;
  Destination src; // holds the local information

  Connection(ibv_context* ctx_)
      : ctx(ctx_),
        protection_domain(nullptr),
        completion_queue(nullptr),
        queue_pair(nullptr) {
    src.local_id = -1;
  }

  Connection(Connection&& c) : Connection(nullptr) {
    std::swap(ctx, c.ctx);
    std::swap(protection_domain, c.protection_domain);
    std::swap(completion_queue, c.completion_queue);
    std::swap(queue_pair, c.queue_pair);
    std::swap(src, c.src);
  }

  Connection(const Connection&) = delete;
  Connection& operator=(Connection&) = delete;

  ~Connection() {
    if (queue_pair != nullptr) {
      ibv_destroy_qp(queue_pair);
    }
    if (completion_queue != nullptr) {
      ibv_destroy_cq(completion_queue);
    }
    if (protection_domain != nullptr) {
      ibv_dealloc_pd(protection_domain);
    }
    if (ctx != nullptr) {
      ibv_close_device(ctx);
    }
  }

  void allocate_protection_domain() {
    protection_domain = ibv_alloc_pd(ctx);
    if (protection_domain == nullptr) {
      throw std::runtime_error("[ibv] Couldn't allocate protection domain");
    }
  }

  void create_completion_queue(int num_entries) {
    completion_queue = ibv_create_cq(ctx, num_entries, nullptr, nullptr, 0);
    if (completion_queue == nullptr) {
      throw std::runtime_error("[ibv] Couldn't create completion queue");
    }
  }
  void create_queue_pair() {
    ibv_qp_init_attr init_attr;
    init_attr.qp_context = ctx;
    init_attr.send_cq = completion_queue;
    init_attr.recv_cq = completion_queue;
    init_attr.srq = nullptr;
    init_attr.cap.max_send_wr = MAX_SEND_WR;
    init_attr.cap.max_recv_wr = MAX_RECV_WR;
    init_attr.cap.max_send_sge = 1;
    init_attr.cap.max_recv_sge = 1;
    init_attr.cap.max_inline_data = 0;
    init_attr.qp_type = IBV_QPT_UC;
    init_attr.sq_sig_all = 0;

    queue_pair = ibv_create_qp(protection_domain, &init_attr);

    if (queue_pair == nullptr) {
      throw std::runtime_error("[ibv] Couldn't create queue pair");
    }
  }
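
  // Compute (and cache) the local destination info that is shared with peers.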
  const Destination& info() {
    if (queue_pair == nullptr || src.local_id >= 0) {
      return src;
    }

    ibv_port_attr port_attr;
    ibv_query_port(ctx, 1, &port_attr);
    ibv_gid gid;
    ibv_query_gid(ctx, 1, 1, &gid);

    src.local_id = port_attr.lid;
    src.queue_pair_number = queue_pair->qp_num;
    src.packet_sequence_number = 7; // TODO: Change to something random
    src.global_identifier = gid;

    return src;
  }
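
  // Queue pairs transition INIT -> RTR -> RTS before they can send and receive.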
  void queue_pair_init() {
    ibv_qp_attr attr = {};
    attr.qp_state = IBV_QPS_INIT;
    attr.port_num = 1;
    attr.pkey_index = 0;
    attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ |
        IBV_ACCESS_REMOTE_WRITE;

    int mask =
        IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS;

    if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
      std::ostringstream msg;
      msg << "[ibv] Changing queue pair to INIT failed with errno " << status;
      throw std::invalid_argument(msg.str());
    }
  }

  void queue_pair_rtr(const Destination& dst) {
    ibv_qp_attr attr = {};
    attr.qp_state = IBV_QPS_RTR;
    attr.path_mtu = IBV_MTU_1024;
    attr.rq_psn = dst.packet_sequence_number;
    attr.dest_qp_num = dst.queue_pair_number;
    attr.ah_attr.dlid = dst.local_id;
    attr.ah_attr.sl = 0;
    attr.ah_attr.src_path_bits = 0;
    attr.ah_attr.port_num = 1;
    attr.ah_attr.is_global = 0;

    if (dst.global_identifier.global.interface_id) {
      attr.ah_attr.is_global = 1;
      attr.ah_attr.grh.hop_limit = 1;
      attr.ah_attr.grh.dgid = dst.global_identifier;
      attr.ah_attr.grh.sgid_index = 1;
    }

    int mask = IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN |
        IBV_QP_RQ_PSN;

    if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
      std::ostringstream msg;
      msg << "[ibv] Changing queue pair to RTR failed with errno " << status;
      throw std::invalid_argument(msg.str());
    }
  }

  void queue_pair_rts() {
    ibv_qp_attr attr = {};
    attr.qp_state = IBV_QPS_RTS;
    attr.sq_psn = src.packet_sequence_number;

    int mask = IBV_QP_STATE | IBV_QP_SQ_PSN;

    if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
      std::ostringstream msg;
      msg << "[ibv] Changing queue pair to RTS failed with errno " << status;
      throw std::invalid_argument(msg.str());
    }
  }
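
  // Post a signaled send of the entire buffer; the completion echoes back the
  // caller-provided wr_id.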
  void post_send(const SharedBuffer& buff, uint64_t work_request_id) {
    ibv_send_wr work_request, *bad_work_request;

    auto entry = buff.to_scatter_gather_entry(protection_domain);
    work_request.wr_id = work_request_id;
    work_request.sg_list = &entry;
    work_request.num_sge = 1;
    work_request.opcode = IBV_WR_SEND;
    work_request.send_flags = IBV_SEND_SIGNALED;
    work_request.next = nullptr;

    if (int status =
            ibv_post_send(queue_pair, &work_request, &bad_work_request);
        status != 0) {
      std::ostringstream msg;
      msg << "[ibv] Send failed with error code " << status;
      throw std::invalid_argument(msg.str());
    }
  }

  void post_recv(const SharedBuffer& buff, uint64_t work_request_id) {
    ibv_recv_wr work_request, *bad_work_request;

    auto entry = buff.to_scatter_gather_entry(protection_domain);
    work_request.wr_id = work_request_id;
    work_request.sg_list = &entry;
    work_request.num_sge = 1;
    work_request.next = nullptr;

    if (int status =
            ibv_post_recv(queue_pair, &work_request, &bad_work_request);
        status != 0) {
      std::ostringstream msg;
      msg << "[ibv] Recv failed with error code " << status;
      throw std::invalid_argument(msg.str());
    }
  }
};

/**
 * Implement a TCP side channel to exchange information about the RDMA
 * connections.
 *
 * Implements a simple all gather where every node sends to rank 0 and rank 0
 * broadcasts to every node.
 */
class SideChannel {
 public:
  SideChannel(int rank, int size, const char* addr) : rank_(rank), size_(size) {
    auto address = detail::parse_address(addr);

    if (rank_ == 0) {
      detail::TCPSocket server(IBV_TAG);
      server.listen(IBV_TAG, address);

      for (int i = 0; i < size - 1; i++) {
        sockets_.push_back(server.accept(IBV_TAG));
      }
    } else {
      sockets_.push_back(detail::TCPSocket::connect(
          IBV_TAG, address, 4, 1000, [](int attempt, int wait) {
            std::cerr << IBV_TAG << " Connection attempt " << attempt
                      << " waiting " << wait << " ms" << std::endl;
          }));
    }
  }

  SideChannel(const SideChannel&) = delete;
  SideChannel& operator=(const SideChannel&) = delete;

  SideChannel(SideChannel&& sc)
      : rank_(sc.rank_), size_(sc.size_), sockets_(std::move(sc.sockets_)) {
    sc.rank_ = -1;
    sc.size_ = -1;
  }
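
  // Gather one value from every rank. Containers are padded to the longest
  // length for the exchange and trimmed back afterwards.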
  template <typename T>
  std::vector<T> all_gather(const T& v) {
    std::vector<T> result(size_);

    // T is a container of stuff like std::vector or std::string
    if constexpr (is_container<T>::value) {
      using U = typename T::value_type;

      // Share the lengths first and set the communication size to be the
      // maximum length of the containers.
      auto lengths = all_gather<int>(v.size());
      auto max_len = *std::max_element(lengths.begin(), lengths.end());
      for (auto& s : result) {
        s.resize(max_len);
      }

      // All gather of length max_len
      if (rank_ == 0) {
        std::copy(v.begin(), v.end(), result[rank_].begin());
        for (int i = 1; i < size_; i++) {
          sockets_[i - 1].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len);
        }
        for (int i = 1; i < size_; i++) {
          for (int j = 0; j < size_; j++) {
            sockets_[i - 1].send(
                IBV_TAG, result[j].data(), sizeof(U) * max_len);
          }
        }
      } else {
        std::copy(v.begin(), v.end(), result[rank_].begin());
        sockets_[0].send(IBV_TAG, result[rank_].data(), sizeof(U) * max_len);
        for (int i = 0; i < size_; i++) {
          sockets_[0].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len);
        }
      }

      // Resize the outputs back to the original length
      for (int i = 0; i < size_; i++) {
        result[i].resize(lengths[i]);
      }
    }

    // T is a scalar
    else {
      if (rank_ == 0) {
        result[rank_] = v;
        for (int i = 1; i < size_; i++) {
          sockets_[i - 1].recv(IBV_TAG, &result[i], sizeof(T));
        }
        for (int i = 1; i < size_; i++) {
          sockets_[i - 1].send(IBV_TAG, result.data(), size_ * sizeof(T));
        }
      } else {
        sockets_[0].send(IBV_TAG, &v, sizeof(T));
        sockets_[0].recv(IBV_TAG, result.data(), size_ * sizeof(T));
      }
    }

    return result;
  }

 private:
  int rank_;
  int size_;
  std::vector<detail::TCPSocket> sockets_;
};

/**
 * Manages a set of connections. Among other things it uses a side channel to
 * exchange the necessary information and then configure the connections to be
 * ready for RDMA operations.
 */
class ConnectionManager {
 public:
  ConnectionManager(
      int rank,
      const std::vector<std::string>& device_names,
      const char* coordinator_addr)
      : rank_(rank),
        size_(device_names.size()),
        side_channel_(rank_, size_, coordinator_addr) {
    create_contexts(device_names);
    if (connections_[rank_].ctx != nullptr) {
      throw std::runtime_error("[ibv] Malformed device file");
    }
  }

  int rank() const {
    return rank_;
  }

  int size() const {
    return size_;
  }

  /**
   * Performs the connection initialization. Namely, after this call all
   * Connection objects should have a queue pair in RTS state.
   */
  void initialize(int num_buffers, size_t num_bytes) {
    // Create the queue pairs
    for (auto& conn : connections_) {
      if (conn.ctx == nullptr) {
        continue;
      }
      conn.allocate_protection_domain();
      conn.create_completion_queue(MAX_SEND_WR + MAX_RECV_WR);
      conn.create_queue_pair();
    }

    allocate_buffers(num_buffers, num_bytes);

    // Gather the information to be exchanged
    std::vector<Destination> info;
    for (auto& conn : connections_) {
      info.emplace_back(conn.info());
    }
    auto all_infos = side_channel_.all_gather(info);

    // Transition queue pairs to RTS
    for (int peer = 0; peer < size_; peer++) {
      if (peer == rank_) {
        continue;
      }
      auto peer_info = all_infos[peer][rank_];
      connections_[peer].queue_pair_init();
      connections_[peer].queue_pair_rtr(peer_info);
      connections_[peer].queue_pair_rts();
    }
  }
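
  // buffers_ is laid out as num_buffers rows of size_ entries: entry
  // (i * size_ + j) is buffer slot i for peer j, and j == rank_ is our own
  // send buffer.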
  void allocate_buffers(int num_buffers, size_t num_bytes) {
    // Deregister any buffers and free the memory
    buffers_.clear();

    // Allocate the memory
    for (int i = 0; i < num_buffers; i++) {
      for (int j = 0; j < size_; j++) {
        buffers_.emplace_back(num_bytes);
      }
    }

    for (int i = 0; i < num_buffers; i++) {
      for (int j = 0; j < size_; j++) {
        // This is our send buffer so register it with all pds so we can send
        // it to all connected devices.
        if (j == rank_) {
          for (auto& conn : connections_) {
            if (conn.ctx != nullptr) {
              buffers_[i * size_ + j].register_to_protection_domain(
                  conn.protection_domain);
            }
          }
        }

        // This is the recv buffer from rank j so register it to rank j's
        // protection domain.
        else {
          buffers_[i * size_ + j].register_to_protection_domain(
              connections_[j].protection_domain);
        }
      }
    }
  }
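
  // wr_id encodes (work type << 16) | (buffer slot << 8) | peer rank so that a
  // completion identifies the operation, the buffer and the peer it belongs to.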
  void send_to(int rank, int buff) {
    connections_[rank].post_send(
        buffers_[buff * size_ + rank_], SEND_WR << 16 | buff << 8 | rank);
  }

  void recv_from(int rank, int buff) {
    connections_[rank].post_recv(
        buffers_[buff * size_ + rank], RECV_WR << 16 | buff << 8 | rank);
  }

  /**
   * Poll all connections, writing up to num_completions work completions into
   * work_completions and returning how many were written.
   */
  int poll(int num_completions, ibv_wc* work_completions) {
    int completions = 0;
    for (int r = 0; r < size_; r++) {
      if (r == rank_) {
        continue;
      }
      if (completions >= num_completions) {
        return completions;
      }

      int c = ibv_poll_cq(
          connections_[r].completion_queue,
          num_completions - completions,
          work_completions + completions);

      completions += c;
    }
    return completions;
  }

  /**
   * Poll a single peer's completion queue.
   */
  int poll(int rank, int num_completions, ibv_wc* work_completions) {
    return ibv_poll_cq(
        connections_[rank].completion_queue,
        num_completions,
        work_completions);
  }

  SharedBuffer& send_buffer(int buff) {
    return buffers_[buff * size_ + rank_];
  }

  SharedBuffer& buffer(int rank, int buff) {
    return buffers_[buff * size_ + rank];
  }
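
  // A cheap barrier: an all gather of a dummy value over the TCP side channel.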
  void barrier() {
    side_channel_.all_gather<int>(0);
  }

 private:
  void create_contexts(const std::vector<std::string>& device_names) {
    int num_devices = 0;
    ibv_device** devices = ibv_get_device_list(&num_devices);
    for (auto& name : device_names) {
      // Empty so add a nullptr context
      if (name.empty()) {
        connections_.emplace_back(nullptr);
        continue;
      }

      // Search for the name and try to open the device
      bool found = false;
      for (int i = 0; i < num_devices; i++) {
        if (name == ibv_get_device_name(devices[i])) {
          auto ctx = ibv_open_device(devices[i]);
          if (ctx == nullptr) {
            std::ostringstream msg;
            msg << "[ibv] Could not open device " << name;
            throw std::runtime_error(msg.str());
          }
          connections_.emplace_back(ctx);
          found = true;
          break;
        }
      }

      // Keep connections_ aligned with ranks even when a name is missing
      if (!found) {
        std::ostringstream msg;
        msg << "[ibv] Could not find device " << name;
        throw std::runtime_error(msg.str());
      }
    }
    ibv_free_device_list(devices);
  }

  int rank_;
  int size_;
  SideChannel side_channel_;
  std::vector<Connection> connections_;
  std::vector<SharedBuffer> buffers_;
};
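
// Elementwise reduction functor used by IBVGroup::all_reduce.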
template <typename T>
struct SumOp {
  void operator()(const T* input, T* output, size_t N) const {
    while (N-- > 0) {
      *output += *input;
      input++;
      output++;
    }
  }
};
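
// Load this rank's device names from the JSON device file, which holds one
// list of device names per rank; null entries become empty strings (no
// device). A hypothetical two-rank file could look like
// [[null, "mlx5_0"], ["mlx5_1", null]] (device names here are illustrative).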
std::vector<std::string> load_device_names(int rank, const char* dev_file) {
  std::vector<std::string> device_names;
  std::ifstream f(dev_file);

  json devices = json::parse(f);
  devices = devices[rank];
  for (auto it = devices.begin(); it != devices.end(); it++) {
    std::string n;
    if (!it->is_null()) {
      n = *it;
    }
    device_names.emplace_back(std::move(n));
  }

  return device_names;
}

namespace mlx::core::distributed::ibv {

class IBVGroup : public GroupImpl {
 public:
  // cm_ is initialized first, so reading cm_.rank() and cm_.size() here is
  // safe (reading from the moved-from cm would not be).
  IBVGroup(ConnectionManager cm)
      : cm_(std::move(cm)), rank_(cm_.rank()), size_(cm_.size()) {}

  Stream communication_stream(StreamOrDevice s) override {
    return to_stream(s, Device::cpu);
  }

  int rank() override {
    return cm_.rank();
  }

  int size() override {
    return cm_.size();
  }

  void all_sum(const array& input, array& output, Stream stream) override {
    dispatch_all_types(output.dtype(), [&](auto type_tag) {
      using T = MLX_GET_TYPE(type_tag);
      all_reduce<T>(input, output, stream, SumOp<T>{});
    });
  }
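
  // The remaining collectives are not implemented yet.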
  void all_max(const array& input, array& output, Stream stream) override {}

  void all_min(const array& input, array& output, Stream stream) override {}

  void all_gather(const array& input, array& output, Stream stream) override {}

  void send(const array& input, int dst, Stream stream) override {}

  void recv(array& out, int src, Stream stream) override {}

  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
    throw std::runtime_error("[ibv] Group split not supported.");
  }

 private:
  void post_recv_all(int buffer) {
    for (int i = 0; i < size_; i++) {
      if (i == rank_) {
        continue;
      }
      cm_.recv_from(i, buffer);
    }
  }

  void post_send_all(int buffer) {
    for (int i = 0; i < size_; i++) {
      if (i == rank_) {
        continue;
      }
      cm_.send_to(i, buffer);
    }
  }
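
  // Fully connected all reduce, pipelined in BUFFER_SIZE chunks with double
  // buffering: while chunk a is in flight, chunk b is staged. The mask_*
  // bitfields track which peers have completed their send/recv per buffer
  // slot.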
  template <typename T, typename ReduceOp>
  void all_reduce(
      const array& input,
      array& output,
      Stream stream,
      ReduceOp reduce_op) {
    auto in_ptr = input.data<T>();
    auto out_ptr = output.data<T>();
    auto& encoder = cpu::get_command_encoder(stream);
    encoder.set_output_array(output);
    encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() {
      // If not inplace all reduce then copy the input to the output first
      if (in_ptr != out_ptr) {
        std::memcpy(out_ptr, in_ptr, size * sizeof(T));
      }

      // Fully connected all reduce
      T* data = out_ptr;
      constexpr int64_t N = BUFFER_SIZE / sizeof(T);
      int64_t total = static_cast<int64_t>(size);
      int64_t offset = N;
      int a = 0, b = 1;

      int mask_init = 1 << rank_;
      int mask_target = (1 << size_) - 1;

      // Handle the first piece of data.
      post_recv_all(a);
      std::copy(data, data + std::min(N, total), cm_.send_buffer(a).begin<T>());
      post_send_all(a);

      int mask_a_send = mask_init;
      int mask_a_recv = mask_init;
      int mask_b_send = mask_init;
      int mask_b_recv = mask_init;

      while (offset < total) {
        // While the previous chunk is in flight copy to the next send buffer
        std::copy(
            data + offset,
            data + std::min(offset + N, total),
            cm_.send_buffer(b).begin<T>());

        // Send if the previous send is already done
        for (int i = 0, m = 1; i < size_; i++, m *= 2) {
          if (i == rank_) {
            continue;
          }
          if (mask_a_send & m) {
            cm_.send_to(i, b);
          }
        }

        // Recv the next buffer if the previous one is already done
        for (int i = 0, m = 1; i < size_; i++, m *= 2) {
          if (i == rank_) {
            continue;
          }
          if (mask_a_recv & m) {
            cm_.recv_from(i, b);
            reduce_op(
                cm_.buffer(i, a).begin<T>(),
                data + std::max(offset - N, int64_t(0)),
                std::min(N, total - offset));
          }
        }

        // Loop until this chunk is all done
        while (mask_a_send != mask_target || mask_a_recv != mask_target) {
          ibv_wc wc[8];
          int n = cm_.poll(8, wc);
          for (int i = 0; i < n; i++) {
            int work_type = wc[i].wr_id >> 16;
            int buff = (wc[i].wr_id >> 8) & 0xff;
            int rank = wc[i].wr_id & 0xff;

            if (work_type == SEND_WR) {
              if (buff == a) {
                cm_.send_to(rank, b);
                mask_a_send |= 1 << rank;
              } else {
                mask_b_send |= 1 << rank;
              }
            } else {
              if (buff == a) {
                cm_.recv_from(rank, b);
                mask_a_recv |= 1 << rank;
                reduce_op(
                    cm_.buffer(rank, a).begin<T>(),
                    data + std::max(offset - N, int64_t(0)),
                    std::min(N, total - offset));
              } else {
                mask_b_recv |= 1 << rank;
              }
            }
          }
        }

        std::swap(a, b);
        mask_a_send = mask_b_send;
        mask_a_recv = mask_b_recv;
        mask_b_send = mask_init;
        mask_b_recv = mask_init;
        offset += N;
      }

      // Drain the completions for the final chunk
      while (mask_a_send != mask_target || mask_a_recv != mask_target) {
        ibv_wc wc[8];
        int n = cm_.poll(8, wc);
        for (int i = 0; i < n; i++) {
          int work_type = wc[i].wr_id >> 16;
          int buff = (wc[i].wr_id >> 8) & 0xff;
          int rank = wc[i].wr_id & 0xff;

          if (work_type == SEND_WR) {
            mask_a_send |= 1 << rank;
          } else {
            mask_a_recv |= 1 << rank;
            reduce_op(
                cm_.buffer(rank, a).begin<T>(),
                data + offset - N,
                std::min(N, total + N - offset));
          }
        }
      }
    });
  }

  ConnectionManager cm_;
  int rank_;
  int size_;
};

bool is_available() {
  return true;
}
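
// Initialization expects MLX_RANK, MLX_IBV_DEVICES (a JSON device file) and
// MLX_IBV_COORDINATOR (ip:port for the TCP side channel) in the environment.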
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
  const char* dev_file = std::getenv("MLX_IBV_DEVICES");
  const char* coordinator = std::getenv("MLX_IBV_COORDINATOR");
  const char* rank_str = std::getenv("MLX_RANK");
  const char* ring_verbose = std::getenv("MLX_IBV_VERBOSE");

  if (!dev_file || !coordinator || !rank_str) {
    if (strict) {
      std::ostringstream msg;
      msg << "[ibv] You need to provide via environment variables a rank (MLX_RANK), "
          << "a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_IBV_COORDINATOR) "
          << "but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "")
          << "\", MLX_IBV_DEVICES=\"" << ((dev_file) ? dev_file : "")
          << "\" and MLX_IBV_COORDINATOR=\""
          << ((coordinator) ? coordinator : "") << "\"";
      throw std::runtime_error(msg.str());
    }
    return nullptr;
  }

  auto rank = std::atoi(rank_str);
  auto device_names = load_device_names(rank, dev_file);

  auto cm = ConnectionManager(rank, device_names, coordinator);
  if (cm.size() > MAX_PEERS) {
    std::ostringstream msg;
    msg << "[ibv] The maximum number of supported peers is " << MAX_PEERS
        << " but " << cm.size() << " was provided";
    throw std::runtime_error(msg.str());
  }

  cm.initialize(NUM_BUFFERS, BUFFER_SIZE);
  cm.barrier();

  return std::make_shared<IBVGroup>(std::move(cm));
}

} // namespace mlx::core::distributed::ibv