mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Initial working all reduce
This commit is contained in:
@@ -119,6 +119,10 @@ if(MLX_BUILD_METAL)
|
||||
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
OUTPUT_VARIABLE MACOS_SDK_VERSION
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||
execute_process(
|
||||
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-path"
|
||||
OUTPUT_VARIABLE CMAKE_OSX_SYSROOT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||
message(
|
||||
|
||||
@@ -4,6 +4,11 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
|
||||
|
||||
if(MLX_BUILD_CPU AND NOT WIN32)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||
endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ibv)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/ibv/ibv.h"
|
||||
#include "mlx/distributed/mpi/mpi.h"
|
||||
#include "mlx/distributed/nccl/nccl.h"
|
||||
#include "mlx/distributed/ring/ring.h"
|
||||
@@ -102,7 +103,8 @@ class EmptyGroup : public GroupImpl {
|
||||
} // namespace detail
|
||||
|
||||
bool is_available() {
|
||||
return mpi::is_available() || ring::is_available() || nccl::is_available();
|
||||
return mpi::is_available() || ring::is_available() || nccl::is_available() ||
|
||||
ibv::is_available();
|
||||
}
|
||||
|
||||
int Group::rank() const {
|
||||
@@ -135,6 +137,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
||||
group = ring::init(strict);
|
||||
} else if (bk == "nccl") {
|
||||
group = nccl::init(strict);
|
||||
} else if (bk == "ibv") {
|
||||
group = ibv::init(strict);
|
||||
} else if (bk == "any") {
|
||||
if (mlx::core::cu::is_available()) {
|
||||
group = nccl::init(false);
|
||||
@@ -148,13 +152,17 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
||||
group = mpi::init(false);
|
||||
bk_ = "mpi";
|
||||
}
|
||||
if (group == nullptr) {
|
||||
group = ibv::init(false);
|
||||
bk_ = "ibv";
|
||||
}
|
||||
if (group == nullptr && strict) {
|
||||
throw std::runtime_error("[distributed] Couldn't initialize any backend");
|
||||
}
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
|
||||
<< "and 'ring' but '" << bk << "' was provided.";
|
||||
msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
|
||||
<< "'ibv' and 'ring' but '" << bk << "' was provided.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
|
||||
2
mlx/distributed/ibv/CMakeLists.txt
Normal file
2
mlx/distributed/ibv/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ibv.cpp)
|
||||
target_link_libraries(mlx PRIVATE rdma)
|
||||
955
mlx/distributed/ibv/ibv.cpp
Normal file
955
mlx/distributed/ibv/ibv.cpp
Normal file
@@ -0,0 +1,955 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <infiniband/verbs.h>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <json.hpp>
|
||||
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/utils.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
constexpr const char* IBV_TAG = "[ibv]";
|
||||
constexpr int NUM_BUFFERS = 2;
|
||||
constexpr int BUFFER_SIZE = 4096;
|
||||
constexpr int MAX_SEND_WR = 32;
|
||||
constexpr int MAX_RECV_WR = 32;
|
||||
constexpr int SEND_WR = 1;
|
||||
constexpr int RECV_WR = 2;
|
||||
constexpr int MAX_PEERS = 8;
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
using json = nlohmann::json;
|
||||
namespace detail = mlx::core::distributed::detail;
|
||||
namespace allocator = mlx::core::allocator;
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct is_container : std::false_type {};
|
||||
|
||||
template <typename T>
|
||||
struct is_container<
|
||||
T,
|
||||
std::void_t<typename T::value_type, typename T::iterator>>
|
||||
: std::true_type {};
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const ibv_gid& gid) {
|
||||
os << std::hex << std::setfill('0');
|
||||
for (int i = 0; i < 16; i += 2) {
|
||||
uint16_t part = (gid.raw[i] << 8) | gid.raw[i + 1];
|
||||
os << std::setw(4) << part;
|
||||
if (i < 14)
|
||||
os << ":";
|
||||
}
|
||||
os << std::dec;
|
||||
return os;
|
||||
}
|
||||
|
||||
void* page_aligned_alloc(size_t num_bytes) {
|
||||
static size_t page_size = sysconf(_SC_PAGESIZE);
|
||||
void * buf;
|
||||
if (posix_memalign(&buf, page_size, num_bytes)) {
|
||||
return nullptr;
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
|
||||
/**
|
||||
* Contains the information that defines a destination to a remote device.
|
||||
* Basically we can compute our own destination and share it with remote hosts
|
||||
* over the side channel.
|
||||
*/
|
||||
struct Destination {
|
||||
int local_id;
|
||||
int queue_pair_number;
|
||||
int packet_sequence_number;
|
||||
ibv_gid global_identifier;
|
||||
};
|
||||
|
||||
/**
|
||||
* A buffer that can be registered to a number of protection domains.
|
||||
*/
|
||||
class SharedBuffer {
|
||||
public:
|
||||
SharedBuffer(size_t num_bytes)
|
||||
: data_(page_aligned_alloc(num_bytes)),
|
||||
num_bytes_(num_bytes) {}
|
||||
~SharedBuffer() {
|
||||
for (auto& [pd, mr] : memory_regions_) {
|
||||
ibv_dereg_mr(mr);
|
||||
}
|
||||
if (data_ != nullptr) {
|
||||
std::free(data_);
|
||||
}
|
||||
}
|
||||
|
||||
SharedBuffer(const SharedBuffer&) = delete;
|
||||
SharedBuffer& operator=(const SharedBuffer&) = delete;
|
||||
SharedBuffer(SharedBuffer&& b)
|
||||
: data_(nullptr), num_bytes_(0) {
|
||||
std::swap(data_, b.data_);
|
||||
std::swap(num_bytes_, b.num_bytes_);
|
||||
std::swap(memory_regions_, b.memory_regions_);
|
||||
}
|
||||
|
||||
void register_to_protection_domain(ibv_pd* protection_domain) {
|
||||
auto [it, inserted] = memory_regions_.insert({protection_domain, nullptr});
|
||||
if (!inserted) {
|
||||
throw std::runtime_error(
|
||||
"[ibv] Buffer can be registered once per protection domain");
|
||||
}
|
||||
|
||||
it->second = ibv_reg_mr(
|
||||
protection_domain,
|
||||
data_,
|
||||
num_bytes_,
|
||||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ |
|
||||
IBV_ACCESS_REMOTE_WRITE);
|
||||
if (!it->second) {
|
||||
throw std::runtime_error("[ibv] Register memory region failed");
|
||||
}
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return num_bytes_;
|
||||
}
|
||||
|
||||
uint32_t local_key(ibv_pd* protection_domain) const {
|
||||
return memory_regions_.at(protection_domain)->lkey;
|
||||
}
|
||||
|
||||
ibv_sge to_scatter_gather_entry(ibv_pd* protection_domain) const {
|
||||
ibv_sge entry;
|
||||
entry.addr = reinterpret_cast<uintptr_t>(data_);
|
||||
entry.length = size();
|
||||
entry.lkey = local_key(protection_domain);
|
||||
return entry;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* data() {
|
||||
return static_cast<T*>(data_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* begin() {
|
||||
return static_cast<T*>(data_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* end() {
|
||||
return static_cast<T*>(data_) + size() / sizeof(T);
|
||||
}
|
||||
|
||||
private:
|
||||
void* data_;
|
||||
size_t num_bytes_;
|
||||
std::unordered_map<ibv_pd*, ibv_mr*> memory_regions_;
|
||||
};
|
||||
|
||||
/**
|
||||
* Manipulates an RDMA connection. Enables (among other things)
|
||||
*
|
||||
* - Creating a queue pair
|
||||
* - Sending and receiving
|
||||
* - Checking completion
|
||||
*/
|
||||
struct Connection {
|
||||
ibv_context* ctx;
|
||||
ibv_pd* protection_domain;
|
||||
ibv_cq* completion_queue;
|
||||
ibv_qp* queue_pair;
|
||||
Destination src; // holds the local information
|
||||
|
||||
Connection(ibv_context* ctx_)
|
||||
: ctx(ctx_),
|
||||
protection_domain(nullptr),
|
||||
completion_queue(nullptr),
|
||||
queue_pair(nullptr) {
|
||||
src.local_id = -1;
|
||||
}
|
||||
|
||||
Connection(Connection&& c) : Connection(nullptr) {
|
||||
std::swap(ctx, c.ctx);
|
||||
std::swap(protection_domain, c.protection_domain);
|
||||
std::swap(completion_queue, c.completion_queue);
|
||||
std::swap(queue_pair, c.queue_pair);
|
||||
std::swap(src, c.src);
|
||||
}
|
||||
|
||||
Connection(const Connection&) = delete;
|
||||
Connection& operator=(Connection&) = delete;
|
||||
|
||||
~Connection() {
|
||||
if (queue_pair != nullptr) {
|
||||
ibv_destroy_qp(queue_pair);
|
||||
}
|
||||
if (completion_queue != nullptr) {
|
||||
ibv_destroy_cq(completion_queue);
|
||||
}
|
||||
if (protection_domain != nullptr) {
|
||||
ibv_dealloc_pd(protection_domain);
|
||||
}
|
||||
if (ctx != nullptr) {
|
||||
ibv_close_device(ctx);
|
||||
}
|
||||
}
|
||||
|
||||
void allocate_protection_domain() {
|
||||
protection_domain = ibv_alloc_pd(ctx);
|
||||
if (protection_domain == nullptr) {
|
||||
throw std::runtime_error("[ibv] Couldn't allocate protection domain");
|
||||
}
|
||||
}
|
||||
|
||||
void create_completion_queue(int num_entries) {
|
||||
completion_queue = ibv_create_cq(ctx, num_entries, nullptr, nullptr, 0);
|
||||
if (completion_queue == nullptr) {
|
||||
throw std::runtime_error("[ibv] Couldn't create completion queue");
|
||||
}
|
||||
}
|
||||
|
||||
void create_queue_pair() {
|
||||
ibv_qp_init_attr init_attr;
|
||||
init_attr.qp_context = ctx;
|
||||
init_attr.qp_context = ctx;
|
||||
init_attr.send_cq = completion_queue;
|
||||
init_attr.recv_cq = completion_queue;
|
||||
init_attr.srq = nullptr;
|
||||
init_attr.cap.max_send_wr = MAX_SEND_WR;
|
||||
init_attr.cap.max_recv_wr = MAX_RECV_WR;
|
||||
init_attr.cap.max_send_sge = 1;
|
||||
init_attr.cap.max_recv_sge = 1;
|
||||
init_attr.cap.max_inline_data = 0;
|
||||
init_attr.qp_type = IBV_QPT_UC;
|
||||
init_attr.sq_sig_all = 0;
|
||||
|
||||
queue_pair = ibv_create_qp(protection_domain, &init_attr);
|
||||
|
||||
if (queue_pair == nullptr) {
|
||||
throw std::runtime_error("[ibv] Couldn't create queue pair");
|
||||
}
|
||||
}
|
||||
|
||||
const Destination& info() {
|
||||
if (queue_pair == nullptr || src.local_id >= 0) {
|
||||
return src;
|
||||
}
|
||||
|
||||
ibv_port_attr port_attr;
|
||||
ibv_query_port(ctx, 1, &port_attr);
|
||||
ibv_gid gid;
|
||||
ibv_query_gid(ctx, 1, 1, &gid);
|
||||
|
||||
src.local_id = port_attr.lid;
|
||||
src.queue_pair_number = queue_pair->qp_num;
|
||||
src.packet_sequence_number = 7; // TODO: Change to sth random
|
||||
src.global_identifier = gid;
|
||||
|
||||
return src;
|
||||
}
|
||||
|
||||
void queue_pair_init() {
|
||||
ibv_qp_attr attr = {};
|
||||
attr.qp_state = IBV_QPS_INIT;
|
||||
attr.port_num = 1;
|
||||
attr.pkey_index = 0;
|
||||
attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ |
|
||||
IBV_ACCESS_REMOTE_WRITE;
|
||||
|
||||
int mask =
|
||||
IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS;
|
||||
|
||||
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] Changing queue pair to INIT failed with errno " << status;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
void queue_pair_rtr(const Destination& dst) {
|
||||
ibv_qp_attr attr = {};
|
||||
memset(&attr, 0, sizeof(attr));
|
||||
attr.qp_state = IBV_QPS_RTR;
|
||||
attr.path_mtu = IBV_MTU_1024;
|
||||
attr.rq_psn = dst.packet_sequence_number;
|
||||
attr.dest_qp_num = dst.queue_pair_number;
|
||||
attr.ah_attr.dlid = dst.local_id;
|
||||
attr.ah_attr.sl = 0;
|
||||
attr.ah_attr.src_path_bits = 0;
|
||||
attr.ah_attr.port_num = 1;
|
||||
attr.ah_attr.is_global = 0;
|
||||
|
||||
if (dst.global_identifier.global.interface_id) {
|
||||
attr.ah_attr.is_global = 1;
|
||||
attr.ah_attr.grh.hop_limit = 1;
|
||||
attr.ah_attr.grh.dgid = dst.global_identifier;
|
||||
attr.ah_attr.grh.sgid_index = 1;
|
||||
}
|
||||
|
||||
int mask = IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN |
|
||||
IBV_QP_RQ_PSN;
|
||||
|
||||
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] Changing queue pair to RTR failed with errno " << status;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
void queue_pair_rts() {
|
||||
ibv_qp_attr attr = {};
|
||||
attr.qp_state = IBV_QPS_RTS;
|
||||
attr.sq_psn = src.packet_sequence_number;
|
||||
|
||||
int mask = IBV_QP_STATE | IBV_QP_SQ_PSN;
|
||||
|
||||
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] Changing queue pair to RTS failed with errno " << status;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
void post_send(const SharedBuffer& buff, uint64_t work_request_id) {
|
||||
ibv_send_wr work_request, *bad_work_request;
|
||||
|
||||
auto entry = buff.to_scatter_gather_entry(protection_domain);
|
||||
work_request.wr_id = work_request_id;
|
||||
work_request.sg_list = &entry;
|
||||
work_request.num_sge = 1;
|
||||
work_request.opcode = IBV_WR_SEND;
|
||||
work_request.send_flags = IBV_SEND_SIGNALED;
|
||||
work_request.next = nullptr;
|
||||
|
||||
if (int status =
|
||||
ibv_post_send(queue_pair, &work_request, &bad_work_request);
|
||||
status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] Send failed with error code " << status;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
void post_recv(const SharedBuffer& buff, uint64_t work_request_id) {
|
||||
ibv_recv_wr work_request, *bad_work_request;
|
||||
|
||||
auto entry = buff.to_scatter_gather_entry(protection_domain);
|
||||
work_request.wr_id = work_request_id;
|
||||
work_request.sg_list = &entry;
|
||||
work_request.num_sge = 1;
|
||||
work_request.next = nullptr;
|
||||
|
||||
if (int status =
|
||||
ibv_post_recv(queue_pair, &work_request, &bad_work_request);
|
||||
status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] Recv failed with error code " << status;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Implement a TCP side channel to exchange information about the RDMA
|
||||
* connections.
|
||||
*
|
||||
* Implements a simple all gather where every node sends to rank 0 and rank 0
|
||||
* broadcasts to every node.
|
||||
*/
|
||||
class SideChannel {
|
||||
public:
|
||||
SideChannel(int rank, int size, const char* addr) : rank_(rank), size_(size) {
|
||||
auto address = detail::parse_address(addr);
|
||||
|
||||
if (rank_ == 0) {
|
||||
detail::TCPSocket server(IBV_TAG);
|
||||
server.listen(IBV_TAG, address);
|
||||
|
||||
for (int i = 0; i < size - 1; i++) {
|
||||
sockets_.push_back(server.accept(IBV_TAG));
|
||||
}
|
||||
} else {
|
||||
sockets_.push_back(detail::TCPSocket::connect(
|
||||
IBV_TAG, address, 4, 1000, [](int attempt, int wait) {
|
||||
std::cerr << IBV_TAG << " Connection attempt " << attempt
|
||||
<< " waiting " << wait << " ms" << std::endl;
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
SideChannel(const SideChannel&) = delete;
|
||||
SideChannel& operator=(const SideChannel&) = delete;
|
||||
|
||||
SideChannel(SideChannel&& sc)
|
||||
: rank_(sc.rank_), size_(sc.size_), sockets_(std::move(sc.sockets_)) {
|
||||
sc.rank_ = -1;
|
||||
sc.size_ = -1;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> all_gather(const T& v) {
|
||||
std::vector<T> result(size_);
|
||||
|
||||
// T is a container of stuff like std::vector or std::string
|
||||
if constexpr (is_container<T>::value) {
|
||||
using U = typename T::value_type;
|
||||
|
||||
// Share the lengths first and set the communication size to be the
|
||||
// maximum length of the containers.
|
||||
auto lengths = all_gather<int>(v.size());
|
||||
auto max_len = *std::max_element(lengths.begin(), lengths.end());
|
||||
for (auto& s : result) {
|
||||
s.resize(max_len);
|
||||
}
|
||||
|
||||
// All gather of length max_len
|
||||
if (rank_ == 0) {
|
||||
std::copy(v.begin(), v.end(), result[rank_].begin());
|
||||
for (int i = 1; i < size_; i++) {
|
||||
sockets_[i - 1].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len);
|
||||
}
|
||||
for (int i = 1; i < size_; i++) {
|
||||
for (int j = 0; j < size_; j++) {
|
||||
sockets_[i - 1].send(
|
||||
IBV_TAG, result[j].data(), sizeof(U) * max_len);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::copy(v.begin(), v.end(), result[rank_].begin());
|
||||
sockets_[0].send(IBV_TAG, result[rank_].data(), sizeof(U) * max_len);
|
||||
for (int i = 0; i < size_; i++) {
|
||||
sockets_[0].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len);
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the outputs back to the original length
|
||||
for (int i = 0; i < size_; i++) {
|
||||
result[i].resize(lengths[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// T is a scalar
|
||||
else {
|
||||
if (rank_ == 0) {
|
||||
result[rank_] = v;
|
||||
for (int i = 1; i < size_; i++) {
|
||||
sockets_[i - 1].recv(IBV_TAG, &result[i], sizeof(T));
|
||||
}
|
||||
for (int i = 1; i < size_; i++) {
|
||||
sockets_[i - 1].send(IBV_TAG, result.data(), size_ * sizeof(T));
|
||||
}
|
||||
} else {
|
||||
sockets_[0].send(IBV_TAG, &v, sizeof(T));
|
||||
sockets_[0].recv(IBV_TAG, result.data(), size_ * sizeof(T));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
int rank_;
|
||||
int size_;
|
||||
std::vector<detail::TCPSocket> sockets_;
|
||||
};
|
||||
|
||||
/**
|
||||
* Manages a set of connections. Among other things it uses a side channel to
|
||||
* exchange the necessary information and then configure the connections to be
|
||||
* ready for RDMA operations.
|
||||
*/
|
||||
class ConnectionManager {
|
||||
public:
|
||||
ConnectionManager(
|
||||
int rank,
|
||||
const std::vector<std::string>& device_names,
|
||||
const char* coordinator_addr)
|
||||
: rank_(rank),
|
||||
size_(device_names.size()),
|
||||
side_channel_(rank_, size_, coordinator_addr) {
|
||||
create_contexts(device_names);
|
||||
if (connections_[rank_].ctx != nullptr) {
|
||||
throw std::runtime_error("[ibv] Malformed device file");
|
||||
}
|
||||
}
|
||||
|
||||
int rank() const {
|
||||
return rank_;
|
||||
}
|
||||
|
||||
int size() const {
|
||||
return size_;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs the connection initialization. Namely, after this call all
|
||||
* Connection objects should have a queue pair in RTS state.
|
||||
*/
|
||||
void initialize(int num_buffers, size_t num_bytes) {
|
||||
// Create the queue pairs
|
||||
for (auto& conn : connections_) {
|
||||
if (conn.ctx == nullptr) {
|
||||
continue;
|
||||
}
|
||||
conn.allocate_protection_domain();
|
||||
conn.create_completion_queue(MAX_SEND_WR + MAX_RECV_WR);
|
||||
conn.create_queue_pair();
|
||||
}
|
||||
|
||||
allocate_buffers(num_buffers, num_bytes);
|
||||
|
||||
// Gather the information to be exchanged
|
||||
std::vector<Destination> info;
|
||||
for (auto& conn : connections_) {
|
||||
info.emplace_back(conn.info());
|
||||
}
|
||||
auto all_infos = side_channel_.all_gather(info);
|
||||
|
||||
// Transition queue pairs to RTS
|
||||
for (int peer = 0; peer < size_; peer++) {
|
||||
if (peer == rank_) {
|
||||
continue;
|
||||
}
|
||||
auto peer_info = all_infos[peer][rank_];
|
||||
connections_[peer].queue_pair_init();
|
||||
connections_[peer].queue_pair_rtr(peer_info);
|
||||
connections_[peer].queue_pair_rts();
|
||||
}
|
||||
}
|
||||
|
||||
void allocate_buffers(int num_buffers, size_t num_bytes) {
|
||||
// Deregister any buffers and free the memory
|
||||
buffers_.clear();
|
||||
|
||||
// Allocate the memory
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
for (int j = 0; j < size_; j++) {
|
||||
buffers_.emplace_back(num_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
for (int j = 0; j < size_; j++) {
|
||||
// This is our send buffer so register it with all pds so we can send
|
||||
// it to all connected devices.
|
||||
if (j == rank_) {
|
||||
for (auto& conn : connections_) {
|
||||
if (conn.ctx != nullptr) {
|
||||
buffers_[i * size_ + j].register_to_protection_domain(
|
||||
conn.protection_domain);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This is the recv buffer from rank j so register it to rank j's
|
||||
// protection domain.
|
||||
else {
|
||||
buffers_[i * size_ + j].register_to_protection_domain(
|
||||
connections_[j].protection_domain);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void send_to(int rank, int buff) {
|
||||
connections_[rank].post_send(
|
||||
buffers_[buff * size_ + rank_],
|
||||
SEND_WR << 16 | buff << 8 | rank);
|
||||
}
|
||||
|
||||
void recv_from(int rank, int buff) {
|
||||
connections_[rank].post_recv(
|
||||
buffers_[buff * size_ + rank], RECV_WR << 16 | buff << 8 | rank);
|
||||
}
|
||||
|
||||
/**
|
||||
* Poll all connections and save the work completions and return the
|
||||
* corresponding length.
|
||||
*/
|
||||
int poll(int num_completions, ibv_wc* work_completions) {
|
||||
int completions = 0;
|
||||
for (int r = 0; r < size_; r++) {
|
||||
if (r == rank_) {
|
||||
continue;
|
||||
}
|
||||
if (completions >= num_completions) {
|
||||
return completions;
|
||||
}
|
||||
|
||||
int c = ibv_poll_cq(
|
||||
connections_[r].completion_queue,
|
||||
num_completions - completions,
|
||||
work_completions + completions);
|
||||
|
||||
completions += c;
|
||||
}
|
||||
return completions;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
int poll(int rank, int num_completions, ibv_wc* work_completions) {
|
||||
return ibv_poll_cq(
|
||||
connections_[rank].completion_queue,
|
||||
num_completions,
|
||||
work_completions);
|
||||
}
|
||||
|
||||
SharedBuffer& send_buffer(int buff) {
|
||||
return buffers_[buff * size_ + rank_];
|
||||
}
|
||||
|
||||
SharedBuffer& buffer(int rank, int buff) {
|
||||
return buffers_[buff * size_ + rank];
|
||||
}
|
||||
|
||||
void barrier() {
|
||||
side_channel_.all_gather<int>(0);
|
||||
}
|
||||
|
||||
private:
|
||||
void create_contexts(const std::vector<std::string>& device_names) {
|
||||
int num_devices = 0;
|
||||
ibv_device** devices = ibv_get_device_list(&num_devices);
|
||||
for (auto& name : device_names) {
|
||||
// Empty so add a nullptr context
|
||||
if (name.empty()) {
|
||||
connections_.emplace_back(nullptr);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Search for the name and try to open the device
|
||||
for (int i = 0; i < num_devices; i++) {
|
||||
if (name == ibv_get_device_name(devices[i])) {
|
||||
auto ctx = ibv_open_device(devices[i]);
|
||||
if (ctx == nullptr) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] Could not open device " << name;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
connections_.emplace_back(ctx);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
ibv_free_device_list(devices);
|
||||
}
|
||||
|
||||
int rank_;
|
||||
int size_;
|
||||
SideChannel side_channel_;
|
||||
std::vector<Connection> connections_;
|
||||
std::vector<SharedBuffer> buffers_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct SumOp {
|
||||
void operator()(const T* input, T* output, size_t N) const {
|
||||
while (N-- > 0) {
|
||||
*output += *input;
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<std::string> load_device_names(int rank, const char* dev_file) {
|
||||
std::vector<std::string> device_names;
|
||||
std::ifstream f(dev_file);
|
||||
|
||||
json devices = json::parse(f);
|
||||
devices = devices[rank];
|
||||
for (auto it = devices.begin(); it != devices.end(); it++) {
|
||||
std::string n;
|
||||
if (!it->is_null()) {
|
||||
n = *it;
|
||||
}
|
||||
device_names.emplace_back(std::move(n));
|
||||
}
|
||||
|
||||
return device_names;
|
||||
}
|
||||
|
||||
namespace mlx::core::distributed::ibv {
|
||||
|
||||
class IBVGroup : public GroupImpl {
|
||||
public:
|
||||
IBVGroup(ConnectionManager cm) : cm_(std::move(cm)), rank_(cm.rank()), size_(cm.size()) {}
|
||||
|
||||
Stream communication_stream(StreamOrDevice s) override {
|
||||
return to_stream(s, Device::cpu);
|
||||
}
|
||||
|
||||
int rank() override {
|
||||
return cm_.rank();
|
||||
}
|
||||
|
||||
int size() override {
|
||||
return cm_.size();
|
||||
}
|
||||
|
||||
void all_sum(const array& input, array& output, Stream stream) override {
|
||||
dispatch_all_types(output.dtype(), [&](auto type_tag) {
|
||||
using T = MLX_GET_TYPE(type_tag);
|
||||
all_reduce<T>(input, output, stream, SumOp<T>{});
|
||||
});
|
||||
}
|
||||
|
||||
void all_max(const array& input, array& output, Stream stream) override {
|
||||
}
|
||||
|
||||
void all_min(const array& input, array& output, Stream stream) override {
|
||||
}
|
||||
|
||||
void all_gather(const array& input, array& output, Stream stream) override {
|
||||
}
|
||||
|
||||
void send(const array& input, int dst, Stream stream) override {
|
||||
}
|
||||
|
||||
void recv(array& out, int src, Stream stream) override {
|
||||
}
|
||||
|
||||
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||
throw std::runtime_error("[ibv] Group split not supported.");
|
||||
}
|
||||
|
||||
private:
|
||||
void post_recv_all(int buffer) {
|
||||
for (int i = 0; i < size_; i++) {
|
||||
if (i == rank_) {
|
||||
continue;
|
||||
}
|
||||
cm_.recv_from(i, buffer);
|
||||
}
|
||||
}
|
||||
|
||||
void post_send_all(int buffer) {
|
||||
for (int i = 0; i < size_; i++) {
|
||||
if (i == rank_) {
|
||||
continue;
|
||||
}
|
||||
cm_.send_to(i, buffer);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename ReduceOp>
|
||||
void all_reduce(
|
||||
const array& input,
|
||||
array& output,
|
||||
Stream stream,
|
||||
ReduceOp reduce_op) {
|
||||
auto in_ptr = input.data<T>();
|
||||
auto out_ptr = output.data<T>();
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(output);
|
||||
encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() {
|
||||
// If not inplace all reduce then copy the input to the output first
|
||||
if (in_ptr != out_ptr) {
|
||||
std::memcpy(out_ptr, in_ptr, size * sizeof(T));
|
||||
}
|
||||
|
||||
// Fully connected all reduce
|
||||
T* data = out_ptr;
|
||||
constexpr int64_t N = BUFFER_SIZE / sizeof(T);
|
||||
int64_t total = static_cast<int64_t>(size);
|
||||
int64_t offset = N;
|
||||
int a = 0, b = 1;
|
||||
|
||||
int mask_init = 1 << rank_;
|
||||
int mask_target = (1 << size_) - 1;
|
||||
|
||||
// Handle the first piece of data.
|
||||
post_recv_all(a);
|
||||
std::copy(
|
||||
data,
|
||||
data + std::min(N, total),
|
||||
cm_.send_buffer(a).begin<T>());
|
||||
post_send_all(a);
|
||||
|
||||
int mask_a_send = mask_init;
|
||||
int mask_a_recv = mask_init;
|
||||
int mask_b_send = mask_init;
|
||||
int mask_b_recv = mask_init;
|
||||
|
||||
while (offset < total) {
|
||||
// While the previous chunk is in flight copy to the next send buffer
|
||||
std::copy(
|
||||
data + offset,
|
||||
data + std::min(offset + N, total),
|
||||
cm_.send_buffer(b).begin<T>());
|
||||
|
||||
// Send if the previous send is already done
|
||||
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
|
||||
if (i == rank_) {
|
||||
continue;
|
||||
}
|
||||
if (mask_a_send & m) {
|
||||
cm_.send_to(i, b);
|
||||
}
|
||||
}
|
||||
|
||||
// Recv the next buffer if the previous one is already done
|
||||
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
|
||||
if (i == rank_) {
|
||||
continue;
|
||||
}
|
||||
if (mask_a_recv & m) {
|
||||
cm_.recv_from(i, b);
|
||||
reduce_op(
|
||||
cm_.buffer(i, a).begin<T>(),
|
||||
data + std::max(offset - N, 0LL),
|
||||
std::min(N, total - offset));
|
||||
}
|
||||
}
|
||||
|
||||
// Loop until this chunk is all done
|
||||
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
|
||||
ibv_wc wc[8];
|
||||
int n = cm_.poll(8, wc);
|
||||
for (int i = 0; i < n; i++) {
|
||||
int work_type = wc[i].wr_id >> 16;
|
||||
int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||
int rank = wc[i].wr_id & 0xff;
|
||||
|
||||
if (work_type == SEND_WR) {
|
||||
if (buff == a) {
|
||||
cm_.send_to(rank, b);
|
||||
mask_a_send |= 1 << rank;
|
||||
} else {
|
||||
mask_b_send |= 1 << rank;
|
||||
}
|
||||
} else {
|
||||
if (buff == a) {
|
||||
cm_.recv_from(rank, b);
|
||||
mask_a_recv |= 1 << rank;
|
||||
reduce_op(
|
||||
cm_.buffer(rank, a).begin<T>(),
|
||||
data + std::max(offset - N, 0LL),
|
||||
std::min(N, total - offset));
|
||||
} else {
|
||||
mask_b_recv |= 1 << rank;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::swap(a, b);
|
||||
mask_a_send = mask_b_send;
|
||||
mask_a_recv = mask_b_recv;
|
||||
mask_b_send = mask_init;
|
||||
mask_b_recv = mask_init;
|
||||
offset += N;
|
||||
}
|
||||
|
||||
{
|
||||
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
|
||||
ibv_wc wc[8];
|
||||
int n = cm_.poll(8, wc);
|
||||
for (int i = 0; i < n; i++) {
|
||||
int work_type = wc[i].wr_id >> 16;
|
||||
int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||
int rank = wc[i].wr_id & 0xff;
|
||||
|
||||
if (work_type == SEND_WR) {
|
||||
mask_a_send |= 1 << rank;
|
||||
} else {
|
||||
mask_a_recv |= 1 << rank;
|
||||
reduce_op(
|
||||
cm_.buffer(rank, a).begin<T>(),
|
||||
data + offset - N,
|
||||
std::min(N, total + N - offset));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
ConnectionManager cm_;
|
||||
int rank_;
|
||||
int size_;
|
||||
};
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||
const char* dev_file = std::getenv("MLX_IBV_DEVICES");
|
||||
const char* coordinator = std::getenv("MLX_IBV_COORDINATOR");
|
||||
const char* rank_str = std::getenv("MLX_RANK");
|
||||
const char* ring_verbose = std::getenv("MLX_IBV_VERBOSE");
|
||||
|
||||
if (!dev_file || !coordinator || !rank_str) {
|
||||
if (strict) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] You need to provide via environment variables a rank (MLX_RANK), "
|
||||
<< "a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_IBV_COORDINATOR) "
|
||||
<< "but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "")
|
||||
<< "\", MLX_IBV_DEVICES=\"" << ((dev_file) ? dev_file : "")
|
||||
<< "\" and MLX_IBV_COORDINATOR=\""
|
||||
<< ((coordinator) ? coordinator : "");
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto rank = std::atoi(rank_str);
|
||||
auto device_names = load_device_names(rank, dev_file);
|
||||
|
||||
auto cm = ConnectionManager(rank, device_names, coordinator);
|
||||
if (cm.size() > MAX_PEERS) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ibv] The maximum number of supported peers is "
|
||||
<< MAX_PEERS << " but " << cm.size() << " was provided";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
cm.initialize(NUM_BUFFERS, BUFFER_SIZE);
|
||||
cm.barrier();
|
||||
|
||||
return std::make_shared<IBVGroup>(std::move(cm));
|
||||
|
||||
//cm.recv_from(rank ^ 1, 0);
|
||||
//cm.barrier();
|
||||
//std::fill(
|
||||
// cm.buffer(rank, 0).begin<int>(),
|
||||
// cm.buffer(rank, 0).end<int>(),
|
||||
// 2 * rank + 1);
|
||||
//cm.barrier();
|
||||
//cm.send_to(rank ^ 1, 0);
|
||||
|
||||
//ibv_wc wc[8];
|
||||
//bool sent = false, recvd = false;
|
||||
//while (!(sent && recvd)) {
|
||||
// int num_completions = cm.poll(8, wc);
|
||||
// for (int i = 0; i < num_completions; i++) {
|
||||
// if (wc[i].status != IBV_WC_SUCCESS) {
|
||||
// std::cout << "Error " << wc[i].status << std::endl;
|
||||
// }
|
||||
|
||||
// int work_type = wc[i].wr_id >> 16;
|
||||
// int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||
// int rank = wc[i].wr_id & 0xff;
|
||||
|
||||
// std::cout << work_type << " " << buff << " " << rank << std::endl;
|
||||
// print_wc(wc[i]);
|
||||
|
||||
// sent |= (work_type == SEND_WR);
|
||||
// recvd |= (work_type == RECV_WR);
|
||||
// }
|
||||
//}
|
||||
|
||||
//std::cout << rank << " ours " << *cm.buffer(rank, 0).data<int>() << std::endl;
|
||||
//std::cout << rank << " theirs " << *cm.buffer(rank ^ 1, 0).data<int>()
|
||||
// << std::endl;
|
||||
|
||||
//return nullptr;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed::ibv
|
||||
12
mlx/distributed/ibv/ibv.h
Normal file
12
mlx/distributed/ibv/ibv.h
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
|
||||
namespace mlx::core::distributed::ibv {
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
|
||||
bool is_available();
|
||||
std::shared_ptr<GroupImpl> init(bool strict = false);
|
||||
|
||||
} // namespace mlx::core::distributed::ibv
|
||||
@@ -1,9 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
@@ -22,6 +19,7 @@
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/utils.h"
|
||||
#include "mlx/threadpool.h"
|
||||
|
||||
#ifndef SOL_TCP
|
||||
@@ -94,6 +92,7 @@ constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
|
||||
constexpr const size_t ALL_SUM_BUFFERS = 2;
|
||||
constexpr const int CONN_ATTEMPTS = 5;
|
||||
constexpr const int CONN_WAIT = 1000;
|
||||
constexpr const char* RING_TAG = "[ring]";
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
using json = nlohmann::json;
|
||||
@@ -296,55 +295,6 @@ class CommunicationThreads {
|
||||
std::unordered_map<int, SocketThread> threads_;
|
||||
};
|
||||
|
||||
struct address_t {
|
||||
sockaddr_storage addr;
|
||||
socklen_t len;
|
||||
|
||||
const sockaddr* get() const {
|
||||
return (struct sockaddr*)&addr;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Parse a sockaddr from an ip and port provided as strings.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip, const std::string& port) {
|
||||
struct addrinfo hints, *res;
|
||||
memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
||||
if (status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Can't parse address " << ip << ":" << port;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
address_t result;
|
||||
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
||||
result.len = res->ai_addrlen;
|
||||
freeaddrinfo(res);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip_port) {
|
||||
auto colon = ip_port.find(":");
|
||||
if (colon == std::string::npos) {
|
||||
std::ostringstream msg;
|
||||
msg << "Can't parse address " << ip_port;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
std::string ip(ip_port.begin(), ip_port.begin() + colon);
|
||||
std::string port(ip_port.begin() + colon + 1, ip_port.end());
|
||||
|
||||
return parse_address(ip, port);
|
||||
}
|
||||
|
||||
/**
|
||||
* Load all addresses from the json hostfile. The hostfile is a list of
|
||||
* addresses in order of rank. For each rank there can be many addresses so
|
||||
@@ -357,15 +307,15 @@ address_t parse_address(const std::string& ip_port) {
|
||||
* ["ip3:5000", "ip3:5001"],
|
||||
* ]
|
||||
*/
|
||||
std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
|
||||
std::vector<std::vector<address_t>> nodes;
|
||||
std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
|
||||
std::vector<std::vector<detail::address_t>> nodes;
|
||||
std::ifstream f(hostfile);
|
||||
|
||||
json hosts = json::parse(f);
|
||||
for (auto& h : hosts) {
|
||||
std::vector<address_t> host;
|
||||
std::vector<detail::address_t> host;
|
||||
for (auto& ips : h) {
|
||||
host.push_back(parse_address(ips.get<std::string>()));
|
||||
host.push_back(std::move(detail::parse_address(ips.get<std::string>())));
|
||||
}
|
||||
nodes.push_back(std::move(host));
|
||||
}
|
||||
@@ -377,73 +327,15 @@ std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
|
||||
* Create a socket and accept one connection for each of the provided
|
||||
* addresses.
|
||||
*/
|
||||
std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
|
||||
std::vector<int> accept_connections(
|
||||
const std::vector<detail::address_t>& addresses) {
|
||||
std::vector<int> sockets;
|
||||
int success;
|
||||
|
||||
for (auto& address : addresses) {
|
||||
// Create the socket to wait for connections from the peers
|
||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Make sure we can launch immediately after shutdown by setting the
|
||||
// reuseaddr option so that we don't get address already in use errors
|
||||
int enable = 1;
|
||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't enable reuseport (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Bind the socket to the address and port
|
||||
success = bind(sock, address.get(), address.len);
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't bind socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Wait for connections
|
||||
success = listen(sock, 0);
|
||||
if (success < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't listen (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
int peer_socket = accept(sock, nullptr, nullptr);
|
||||
if (peer_socket < 0) {
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Accept failed (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Close the listening socket
|
||||
shutdown(sock, 2);
|
||||
close(sock);
|
||||
|
||||
sockets.push_back(peer_socket);
|
||||
detail::TCPSocket socket(RING_TAG);
|
||||
socket.listen(RING_TAG, address);
|
||||
sockets.push_back(socket.accept(RING_TAG));
|
||||
}
|
||||
|
||||
return sockets;
|
||||
@@ -454,55 +346,33 @@ std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
|
||||
* provided addresses.
|
||||
*/
|
||||
std::vector<int> make_connections(
|
||||
const std::vector<address_t>& addresses,
|
||||
const std::vector<detail::address_t>& addresses,
|
||||
bool verbose) {
|
||||
std::vector<int> sockets;
|
||||
int success;
|
||||
|
||||
for (auto& address : addresses) {
|
||||
int sock;
|
||||
|
||||
// Attempt to connect to the peer CONN_ATTEMPTS times with exponential
|
||||
// backoff. TODO: Do we need that?
|
||||
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
|
||||
// Create the socket
|
||||
sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
if (attempt > 0) {
|
||||
int wait = (1 << (attempt - 1)) * CONN_WAIT;
|
||||
log_info(
|
||||
verbose,
|
||||
"Attempt",
|
||||
attempt,
|
||||
"wait",
|
||||
wait,
|
||||
"ms (error:",
|
||||
errno,
|
||||
")");
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
||||
}
|
||||
|
||||
success = connect(sock, address.get(), address.len);
|
||||
if (success == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Couldn't connect (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
sockets.push_back(sock);
|
||||
sockets.push_back(detail::TCPSocket::connect(
|
||||
RING_TAG,
|
||||
address,
|
||||
CONN_ATTEMPTS,
|
||||
CONN_WAIT,
|
||||
[verbose](int attempt, int wait) {
|
||||
log_info(
|
||||
verbose,
|
||||
"Attempt",
|
||||
attempt,
|
||||
"waiting",
|
||||
wait,
|
||||
"ms (error:",
|
||||
errno,
|
||||
")");
|
||||
}));
|
||||
}
|
||||
|
||||
return sockets;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct SumOp {
|
||||
void operator()(const T* input, T* output, size_t N) {
|
||||
@@ -540,7 +410,10 @@ struct MinOp {
|
||||
|
||||
class RingGroup : public GroupImpl {
|
||||
public:
|
||||
RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
|
||||
RingGroup(
|
||||
int rank,
|
||||
std::vector<std::vector<detail::address_t>> nodes,
|
||||
bool verbose)
|
||||
: rank_(rank), verbose_(verbose), pool_(0) {
|
||||
if (rank_ > 0 && rank_ >= nodes.size()) {
|
||||
throw std::runtime_error(
|
||||
|
||||
189
mlx/distributed/utils.cpp
Normal file
189
mlx/distributed/utils.cpp
Normal file
@@ -0,0 +1,189 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <netdb.h>
|
||||
#include <unistd.h>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
|
||||
#include "mlx/distributed/utils.h"
|
||||
|
||||
namespace mlx::core::distributed::detail {
|
||||
|
||||
/**
|
||||
* Parse a sockaddr from an ip and port provided as strings.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip, const std::string& port) {
|
||||
struct addrinfo hints, *res;
|
||||
memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
||||
if (status != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Can't parse address " << ip << ":" << port;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
address_t result;
|
||||
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
||||
result.len = res->ai_addrlen;
|
||||
freeaddrinfo(res);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip_port) {
|
||||
auto colon = ip_port.find(":");
|
||||
if (colon == std::string::npos) {
|
||||
std::ostringstream msg;
|
||||
msg << "Can't parse address " << ip_port;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
std::string ip(ip_port.begin(), ip_port.begin() + colon);
|
||||
std::string port(ip_port.begin() + colon + 1, ip_port.end());
|
||||
|
||||
return parse_address(ip, port);
|
||||
}
|
||||
|
||||
TCPSocket::TCPSocket(const char* tag) {
|
||||
sock_ = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock_ < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
TCPSocket::TCPSocket(TCPSocket&& s) {
|
||||
sock_ = s.sock_;
|
||||
s.sock_ = -1;
|
||||
}
|
||||
|
||||
TCPSocket::TCPSocket(int s) : sock_(s) {}
|
||||
|
||||
TCPSocket::~TCPSocket() {
|
||||
if (sock_ > 0) {
|
||||
shutdown(sock_, 2);
|
||||
close(sock_);
|
||||
}
|
||||
}
|
||||
|
||||
void TCPSocket::listen(const char* tag, const address_t& addr) {
|
||||
int success;
|
||||
|
||||
// Make sure we can launch immediately after shutdown by setting the
|
||||
// reuseaddr option so that we don't get address already in use errors
|
||||
int enable = 1;
|
||||
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't enable reuseaddr (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't enable reuseport (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Bind the socket to the address and port
|
||||
success = bind(sock_, addr.get(), addr.len);
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't bind socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Prepare waiting for connections
|
||||
success = ::listen(sock_, 0);
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't listen (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
TCPSocket TCPSocket::accept(const char* tag) {
|
||||
int peer = ::accept(sock_, nullptr, nullptr);
|
||||
if (peer < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Accept failed (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return TCPSocket(peer);
|
||||
}
|
||||
|
||||
void TCPSocket::send(const char* tag, const void* data, size_t len) {
|
||||
while (len > 0) {
|
||||
auto n = ::send(sock_, data, len, 0);
|
||||
if (n <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Send failed with errno=" << errno;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
len -= n;
|
||||
data = static_cast<const char*>(data) + n;
|
||||
}
|
||||
}
|
||||
|
||||
void TCPSocket::recv(const char* tag, void* data, size_t len) {
|
||||
while (len > 0) {
|
||||
auto n = ::recv(sock_, data, len, 0);
|
||||
if (n <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Recv failed with errno=" << errno;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
len -= n;
|
||||
data = static_cast<char*>(data) + n;
|
||||
}
|
||||
}
|
||||
|
||||
TCPSocket TCPSocket::connect(
|
||||
const char* tag,
|
||||
const address_t& addr,
|
||||
int num_retries,
|
||||
int wait,
|
||||
std::function<void(int, int)> cb) {
|
||||
int sock, success;
|
||||
|
||||
// Attempt to connect `num_retries` times with exponential backoff.
|
||||
for (int attempt = 0; attempt < num_retries; attempt++) {
|
||||
// Create the socket
|
||||
sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't create socket to connect (error: " << errno
|
||||
<< ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
success = ::connect(sock, addr.get(), addr.len);
|
||||
if (success == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
cb(attempt, wait);
|
||||
if (wait > 0) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
||||
}
|
||||
|
||||
wait <<= 1;
|
||||
}
|
||||
|
||||
if (success < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << tag << " Couldn't connect (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return TCPSocket(sock);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed::detail
|
||||
62
mlx/distributed/utils.h
Normal file
62
mlx/distributed/utils.h
Normal file
@@ -0,0 +1,62 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sys/socket.h>
|
||||
|
||||
namespace mlx::core::distributed::detail {
|
||||
|
||||
struct address_t {
|
||||
sockaddr_storage addr;
|
||||
socklen_t len;
|
||||
|
||||
const sockaddr* get() const {
|
||||
return (struct sockaddr*)&addr;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Parse a sockaddr from an ip and port provided as strings.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip, const std::string& port);
|
||||
|
||||
/**
|
||||
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||
*/
|
||||
address_t parse_address(const std::string& ip_port);
|
||||
|
||||
/**
|
||||
* Small wrapper over a TCP socket to simplify initiating connections.
|
||||
*/
|
||||
class TCPSocket {
|
||||
public:
|
||||
TCPSocket(const char* tag);
|
||||
TCPSocket(const TCPSocket&) = delete;
|
||||
TCPSocket& operator=(const TCPSocket&) = delete;
|
||||
TCPSocket(TCPSocket&& s);
|
||||
~TCPSocket();
|
||||
|
||||
void listen(const char* tag, const address_t& addr);
|
||||
TCPSocket accept(const char* tag);
|
||||
|
||||
void send(const char* tag, const void* data, size_t len);
|
||||
void recv(const char* tag, void* data, size_t len);
|
||||
|
||||
operator int() const {
|
||||
return sock_;
|
||||
}
|
||||
|
||||
static TCPSocket connect(
|
||||
const char* tag,
|
||||
const address_t& addr,
|
||||
int num_retries = 1,
|
||||
int wait = 0,
|
||||
std::function<void(int, int)> cb = nullptr);
|
||||
|
||||
private:
|
||||
TCPSocket(int sock);
|
||||
|
||||
int sock_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::distributed::detail
|
||||
Reference in New Issue
Block a user