mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Initial working all reduce
This commit is contained in:
@@ -119,6 +119,10 @@ if(MLX_BUILD_METAL)
|
|||||||
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||||
OUTPUT_VARIABLE MACOS_SDK_VERSION
|
OUTPUT_VARIABLE MACOS_SDK_VERSION
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||||
|
execute_process(
|
||||||
|
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-path"
|
||||||
|
OUTPUT_VARIABLE CMAKE_OSX_SYSROOT
|
||||||
|
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||||
|
|
||||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||||
message(
|
message(
|
||||||
|
|||||||
@@ -4,6 +4,11 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
|
||||||
|
|
||||||
|
if(MLX_BUILD_CPU AND NOT WIN32)
|
||||||
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||||
|
endif()
|
||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ibv)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#include "mlx/backend/cuda/cuda.h"
|
#include "mlx/backend/cuda/cuda.h"
|
||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
|
#include "mlx/distributed/ibv/ibv.h"
|
||||||
#include "mlx/distributed/mpi/mpi.h"
|
#include "mlx/distributed/mpi/mpi.h"
|
||||||
#include "mlx/distributed/nccl/nccl.h"
|
#include "mlx/distributed/nccl/nccl.h"
|
||||||
#include "mlx/distributed/ring/ring.h"
|
#include "mlx/distributed/ring/ring.h"
|
||||||
@@ -102,7 +103,8 @@ class EmptyGroup : public GroupImpl {
|
|||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
bool is_available() {
|
bool is_available() {
|
||||||
return mpi::is_available() || ring::is_available() || nccl::is_available();
|
return mpi::is_available() || ring::is_available() || nccl::is_available() ||
|
||||||
|
ibv::is_available();
|
||||||
}
|
}
|
||||||
|
|
||||||
int Group::rank() const {
|
int Group::rank() const {
|
||||||
@@ -135,6 +137,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
|||||||
group = ring::init(strict);
|
group = ring::init(strict);
|
||||||
} else if (bk == "nccl") {
|
} else if (bk == "nccl") {
|
||||||
group = nccl::init(strict);
|
group = nccl::init(strict);
|
||||||
|
} else if (bk == "ibv") {
|
||||||
|
group = ibv::init(strict);
|
||||||
} else if (bk == "any") {
|
} else if (bk == "any") {
|
||||||
if (mlx::core::cu::is_available()) {
|
if (mlx::core::cu::is_available()) {
|
||||||
group = nccl::init(false);
|
group = nccl::init(false);
|
||||||
@@ -148,13 +152,17 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
|||||||
group = mpi::init(false);
|
group = mpi::init(false);
|
||||||
bk_ = "mpi";
|
bk_ = "mpi";
|
||||||
}
|
}
|
||||||
|
if (group == nullptr) {
|
||||||
|
group = ibv::init(false);
|
||||||
|
bk_ = "ibv";
|
||||||
|
}
|
||||||
if (group == nullptr && strict) {
|
if (group == nullptr && strict) {
|
||||||
throw std::runtime_error("[distributed] Couldn't initialize any backend");
|
throw std::runtime_error("[distributed] Couldn't initialize any backend");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
|
msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
|
||||||
<< "and 'ring' but '" << bk << "' was provided.";
|
<< "'ibv' and 'ring' but '" << bk << "' was provided.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
2
mlx/distributed/ibv/CMakeLists.txt
Normal file
2
mlx/distributed/ibv/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ibv.cpp)
|
||||||
|
target_link_libraries(mlx PRIVATE rdma)
|
||||||
955
mlx/distributed/ibv/ibv.cpp
Normal file
955
mlx/distributed/ibv/ibv.cpp
Normal file
@@ -0,0 +1,955 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <infiniband/verbs.h>
|
||||||
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <unistd.h>
|
||||||
|
|
||||||
|
#include <json.hpp>
|
||||||
|
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
|
#include "mlx/distributed/utils.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
|
constexpr const char* IBV_TAG = "[ibv]";
|
||||||
|
constexpr int NUM_BUFFERS = 2;
|
||||||
|
constexpr int BUFFER_SIZE = 4096;
|
||||||
|
constexpr int MAX_SEND_WR = 32;
|
||||||
|
constexpr int MAX_RECV_WR = 32;
|
||||||
|
constexpr int SEND_WR = 1;
|
||||||
|
constexpr int RECV_WR = 2;
|
||||||
|
constexpr int MAX_PEERS = 8;
|
||||||
|
|
||||||
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
using json = nlohmann::json;
|
||||||
|
namespace detail = mlx::core::distributed::detail;
|
||||||
|
namespace allocator = mlx::core::allocator;
|
||||||
|
|
||||||
|
template <typename T, typename = void>
|
||||||
|
struct is_container : std::false_type {};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct is_container<
|
||||||
|
T,
|
||||||
|
std::void_t<typename T::value_type, typename T::iterator>>
|
||||||
|
: std::true_type {};
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, const ibv_gid& gid) {
|
||||||
|
os << std::hex << std::setfill('0');
|
||||||
|
for (int i = 0; i < 16; i += 2) {
|
||||||
|
uint16_t part = (gid.raw[i] << 8) | gid.raw[i + 1];
|
||||||
|
os << std::setw(4) << part;
|
||||||
|
if (i < 14)
|
||||||
|
os << ":";
|
||||||
|
}
|
||||||
|
os << std::dec;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
void* page_aligned_alloc(size_t num_bytes) {
|
||||||
|
static size_t page_size = sysconf(_SC_PAGESIZE);
|
||||||
|
void * buf;
|
||||||
|
if (posix_memalign(&buf, page_size, num_bytes)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return buf;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Contains the information that defines a destination to a remote device.
|
||||||
|
* Basically we can compute our own destination and share it with remote hosts
|
||||||
|
* over the side channel.
|
||||||
|
*/
|
||||||
|
struct Destination {
|
||||||
|
int local_id;
|
||||||
|
int queue_pair_number;
|
||||||
|
int packet_sequence_number;
|
||||||
|
ibv_gid global_identifier;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A buffer that can be registered to a number of protection domains.
|
||||||
|
*/
|
||||||
|
class SharedBuffer {
|
||||||
|
public:
|
||||||
|
SharedBuffer(size_t num_bytes)
|
||||||
|
: data_(page_aligned_alloc(num_bytes)),
|
||||||
|
num_bytes_(num_bytes) {}
|
||||||
|
~SharedBuffer() {
|
||||||
|
for (auto& [pd, mr] : memory_regions_) {
|
||||||
|
ibv_dereg_mr(mr);
|
||||||
|
}
|
||||||
|
if (data_ != nullptr) {
|
||||||
|
std::free(data_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SharedBuffer(const SharedBuffer&) = delete;
|
||||||
|
SharedBuffer& operator=(const SharedBuffer&) = delete;
|
||||||
|
SharedBuffer(SharedBuffer&& b)
|
||||||
|
: data_(nullptr), num_bytes_(0) {
|
||||||
|
std::swap(data_, b.data_);
|
||||||
|
std::swap(num_bytes_, b.num_bytes_);
|
||||||
|
std::swap(memory_regions_, b.memory_regions_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void register_to_protection_domain(ibv_pd* protection_domain) {
|
||||||
|
auto [it, inserted] = memory_regions_.insert({protection_domain, nullptr});
|
||||||
|
if (!inserted) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[ibv] Buffer can be registered once per protection domain");
|
||||||
|
}
|
||||||
|
|
||||||
|
it->second = ibv_reg_mr(
|
||||||
|
protection_domain,
|
||||||
|
data_,
|
||||||
|
num_bytes_,
|
||||||
|
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ |
|
||||||
|
IBV_ACCESS_REMOTE_WRITE);
|
||||||
|
if (!it->second) {
|
||||||
|
throw std::runtime_error("[ibv] Register memory region failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t size() const {
|
||||||
|
return num_bytes_;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t local_key(ibv_pd* protection_domain) const {
|
||||||
|
return memory_regions_.at(protection_domain)->lkey;
|
||||||
|
}
|
||||||
|
|
||||||
|
ibv_sge to_scatter_gather_entry(ibv_pd* protection_domain) const {
|
||||||
|
ibv_sge entry;
|
||||||
|
entry.addr = reinterpret_cast<uintptr_t>(data_);
|
||||||
|
entry.length = size();
|
||||||
|
entry.lkey = local_key(protection_domain);
|
||||||
|
return entry;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T* data() {
|
||||||
|
return static_cast<T*>(data_);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T* begin() {
|
||||||
|
return static_cast<T*>(data_);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T* end() {
|
||||||
|
return static_cast<T*>(data_) + size() / sizeof(T);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void* data_;
|
||||||
|
size_t num_bytes_;
|
||||||
|
std::unordered_map<ibv_pd*, ibv_mr*> memory_regions_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Manipulates an RDMA connection. Enables (among other things)
|
||||||
|
*
|
||||||
|
* - Creating a queue pair
|
||||||
|
* - Sending and receiving
|
||||||
|
* - Checking completion
|
||||||
|
*/
|
||||||
|
struct Connection {
|
||||||
|
ibv_context* ctx;
|
||||||
|
ibv_pd* protection_domain;
|
||||||
|
ibv_cq* completion_queue;
|
||||||
|
ibv_qp* queue_pair;
|
||||||
|
Destination src; // holds the local information
|
||||||
|
|
||||||
|
Connection(ibv_context* ctx_)
|
||||||
|
: ctx(ctx_),
|
||||||
|
protection_domain(nullptr),
|
||||||
|
completion_queue(nullptr),
|
||||||
|
queue_pair(nullptr) {
|
||||||
|
src.local_id = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Connection(Connection&& c) : Connection(nullptr) {
|
||||||
|
std::swap(ctx, c.ctx);
|
||||||
|
std::swap(protection_domain, c.protection_domain);
|
||||||
|
std::swap(completion_queue, c.completion_queue);
|
||||||
|
std::swap(queue_pair, c.queue_pair);
|
||||||
|
std::swap(src, c.src);
|
||||||
|
}
|
||||||
|
|
||||||
|
Connection(const Connection&) = delete;
|
||||||
|
Connection& operator=(Connection&) = delete;
|
||||||
|
|
||||||
|
~Connection() {
|
||||||
|
if (queue_pair != nullptr) {
|
||||||
|
ibv_destroy_qp(queue_pair);
|
||||||
|
}
|
||||||
|
if (completion_queue != nullptr) {
|
||||||
|
ibv_destroy_cq(completion_queue);
|
||||||
|
}
|
||||||
|
if (protection_domain != nullptr) {
|
||||||
|
ibv_dealloc_pd(protection_domain);
|
||||||
|
}
|
||||||
|
if (ctx != nullptr) {
|
||||||
|
ibv_close_device(ctx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void allocate_protection_domain() {
|
||||||
|
protection_domain = ibv_alloc_pd(ctx);
|
||||||
|
if (protection_domain == nullptr) {
|
||||||
|
throw std::runtime_error("[ibv] Couldn't allocate protection domain");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void create_completion_queue(int num_entries) {
|
||||||
|
completion_queue = ibv_create_cq(ctx, num_entries, nullptr, nullptr, 0);
|
||||||
|
if (completion_queue == nullptr) {
|
||||||
|
throw std::runtime_error("[ibv] Couldn't create completion queue");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void create_queue_pair() {
|
||||||
|
ibv_qp_init_attr init_attr;
|
||||||
|
init_attr.qp_context = ctx;
|
||||||
|
init_attr.qp_context = ctx;
|
||||||
|
init_attr.send_cq = completion_queue;
|
||||||
|
init_attr.recv_cq = completion_queue;
|
||||||
|
init_attr.srq = nullptr;
|
||||||
|
init_attr.cap.max_send_wr = MAX_SEND_WR;
|
||||||
|
init_attr.cap.max_recv_wr = MAX_RECV_WR;
|
||||||
|
init_attr.cap.max_send_sge = 1;
|
||||||
|
init_attr.cap.max_recv_sge = 1;
|
||||||
|
init_attr.cap.max_inline_data = 0;
|
||||||
|
init_attr.qp_type = IBV_QPT_UC;
|
||||||
|
init_attr.sq_sig_all = 0;
|
||||||
|
|
||||||
|
queue_pair = ibv_create_qp(protection_domain, &init_attr);
|
||||||
|
|
||||||
|
if (queue_pair == nullptr) {
|
||||||
|
throw std::runtime_error("[ibv] Couldn't create queue pair");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const Destination& info() {
|
||||||
|
if (queue_pair == nullptr || src.local_id >= 0) {
|
||||||
|
return src;
|
||||||
|
}
|
||||||
|
|
||||||
|
ibv_port_attr port_attr;
|
||||||
|
ibv_query_port(ctx, 1, &port_attr);
|
||||||
|
ibv_gid gid;
|
||||||
|
ibv_query_gid(ctx, 1, 1, &gid);
|
||||||
|
|
||||||
|
src.local_id = port_attr.lid;
|
||||||
|
src.queue_pair_number = queue_pair->qp_num;
|
||||||
|
src.packet_sequence_number = 7; // TODO: Change to sth random
|
||||||
|
src.global_identifier = gid;
|
||||||
|
|
||||||
|
return src;
|
||||||
|
}
|
||||||
|
|
||||||
|
void queue_pair_init() {
|
||||||
|
ibv_qp_attr attr = {};
|
||||||
|
attr.qp_state = IBV_QPS_INIT;
|
||||||
|
attr.port_num = 1;
|
||||||
|
attr.pkey_index = 0;
|
||||||
|
attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ |
|
||||||
|
IBV_ACCESS_REMOTE_WRITE;
|
||||||
|
|
||||||
|
int mask =
|
||||||
|
IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS;
|
||||||
|
|
||||||
|
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[ibv] Changing queue pair to INIT failed with errno " << status;
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void queue_pair_rtr(const Destination& dst) {
|
||||||
|
ibv_qp_attr attr = {};
|
||||||
|
memset(&attr, 0, sizeof(attr));
|
||||||
|
attr.qp_state = IBV_QPS_RTR;
|
||||||
|
attr.path_mtu = IBV_MTU_1024;
|
||||||
|
attr.rq_psn = dst.packet_sequence_number;
|
||||||
|
attr.dest_qp_num = dst.queue_pair_number;
|
||||||
|
attr.ah_attr.dlid = dst.local_id;
|
||||||
|
attr.ah_attr.sl = 0;
|
||||||
|
attr.ah_attr.src_path_bits = 0;
|
||||||
|
attr.ah_attr.port_num = 1;
|
||||||
|
attr.ah_attr.is_global = 0;
|
||||||
|
|
||||||
|
if (dst.global_identifier.global.interface_id) {
|
||||||
|
attr.ah_attr.is_global = 1;
|
||||||
|
attr.ah_attr.grh.hop_limit = 1;
|
||||||
|
attr.ah_attr.grh.dgid = dst.global_identifier;
|
||||||
|
attr.ah_attr.grh.sgid_index = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int mask = IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN |
|
||||||
|
IBV_QP_RQ_PSN;
|
||||||
|
|
||||||
|
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[ibv] Changing queue pair to RTR failed with errno " << status;
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void queue_pair_rts() {
|
||||||
|
ibv_qp_attr attr = {};
|
||||||
|
attr.qp_state = IBV_QPS_RTS;
|
||||||
|
attr.sq_psn = src.packet_sequence_number;
|
||||||
|
|
||||||
|
int mask = IBV_QP_STATE | IBV_QP_SQ_PSN;
|
||||||
|
|
||||||
|
if (int status = ibv_modify_qp(queue_pair, &attr, mask); status != 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[ibv] Changing queue pair to RTS failed with errno " << status;
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void post_send(const SharedBuffer& buff, uint64_t work_request_id) {
|
||||||
|
ibv_send_wr work_request, *bad_work_request;
|
||||||
|
|
||||||
|
auto entry = buff.to_scatter_gather_entry(protection_domain);
|
||||||
|
work_request.wr_id = work_request_id;
|
||||||
|
work_request.sg_list = &entry;
|
||||||
|
work_request.num_sge = 1;
|
||||||
|
work_request.opcode = IBV_WR_SEND;
|
||||||
|
work_request.send_flags = IBV_SEND_SIGNALED;
|
||||||
|
work_request.next = nullptr;
|
||||||
|
|
||||||
|
if (int status =
|
||||||
|
ibv_post_send(queue_pair, &work_request, &bad_work_request);
|
||||||
|
status != 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[ibv] Send failed with error code " << status;
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void post_recv(const SharedBuffer& buff, uint64_t work_request_id) {
|
||||||
|
ibv_recv_wr work_request, *bad_work_request;
|
||||||
|
|
||||||
|
auto entry = buff.to_scatter_gather_entry(protection_domain);
|
||||||
|
work_request.wr_id = work_request_id;
|
||||||
|
work_request.sg_list = &entry;
|
||||||
|
work_request.num_sge = 1;
|
||||||
|
work_request.next = nullptr;
|
||||||
|
|
||||||
|
if (int status =
|
||||||
|
ibv_post_recv(queue_pair, &work_request, &bad_work_request);
|
||||||
|
status != 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[ibv] Recv failed with error code " << status;
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implement a TCP side channel to exchange information about the RDMA
|
||||||
|
* connections.
|
||||||
|
*
|
||||||
|
* Implements a simple all gather where every node sends to rank 0 and rank 0
|
||||||
|
* broadcasts to every node.
|
||||||
|
*/
|
||||||
|
class SideChannel {
|
||||||
|
public:
|
||||||
|
SideChannel(int rank, int size, const char* addr) : rank_(rank), size_(size) {
|
||||||
|
auto address = detail::parse_address(addr);
|
||||||
|
|
||||||
|
if (rank_ == 0) {
|
||||||
|
detail::TCPSocket server(IBV_TAG);
|
||||||
|
server.listen(IBV_TAG, address);
|
||||||
|
|
||||||
|
for (int i = 0; i < size - 1; i++) {
|
||||||
|
sockets_.push_back(server.accept(IBV_TAG));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sockets_.push_back(detail::TCPSocket::connect(
|
||||||
|
IBV_TAG, address, 4, 1000, [](int attempt, int wait) {
|
||||||
|
std::cerr << IBV_TAG << " Connection attempt " << attempt
|
||||||
|
<< " waiting " << wait << " ms" << std::endl;
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SideChannel(const SideChannel&) = delete;
|
||||||
|
SideChannel& operator=(const SideChannel&) = delete;
|
||||||
|
|
||||||
|
SideChannel(SideChannel&& sc)
|
||||||
|
: rank_(sc.rank_), size_(sc.size_), sockets_(std::move(sc.sockets_)) {
|
||||||
|
sc.rank_ = -1;
|
||||||
|
sc.size_ = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::vector<T> all_gather(const T& v) {
|
||||||
|
std::vector<T> result(size_);
|
||||||
|
|
||||||
|
// T is a container of stuff like std::vector or std::string
|
||||||
|
if constexpr (is_container<T>::value) {
|
||||||
|
using U = typename T::value_type;
|
||||||
|
|
||||||
|
// Share the lengths first and set the communication size to be the
|
||||||
|
// maximum length of the containers.
|
||||||
|
auto lengths = all_gather<int>(v.size());
|
||||||
|
auto max_len = *std::max_element(lengths.begin(), lengths.end());
|
||||||
|
for (auto& s : result) {
|
||||||
|
s.resize(max_len);
|
||||||
|
}
|
||||||
|
|
||||||
|
// All gather of length max_len
|
||||||
|
if (rank_ == 0) {
|
||||||
|
std::copy(v.begin(), v.end(), result[rank_].begin());
|
||||||
|
for (int i = 1; i < size_; i++) {
|
||||||
|
sockets_[i - 1].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len);
|
||||||
|
}
|
||||||
|
for (int i = 1; i < size_; i++) {
|
||||||
|
for (int j = 0; j < size_; j++) {
|
||||||
|
sockets_[i - 1].send(
|
||||||
|
IBV_TAG, result[j].data(), sizeof(U) * max_len);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
std::copy(v.begin(), v.end(), result[rank_].begin());
|
||||||
|
sockets_[0].send(IBV_TAG, result[rank_].data(), sizeof(U) * max_len);
|
||||||
|
for (int i = 0; i < size_; i++) {
|
||||||
|
sockets_[0].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resize the outputs back to the original length
|
||||||
|
for (int i = 0; i < size_; i++) {
|
||||||
|
result[i].resize(lengths[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// T is a scalar
|
||||||
|
else {
|
||||||
|
if (rank_ == 0) {
|
||||||
|
result[rank_] = v;
|
||||||
|
for (int i = 1; i < size_; i++) {
|
||||||
|
sockets_[i - 1].recv(IBV_TAG, &result[i], sizeof(T));
|
||||||
|
}
|
||||||
|
for (int i = 1; i < size_; i++) {
|
||||||
|
sockets_[i - 1].send(IBV_TAG, result.data(), size_ * sizeof(T));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sockets_[0].send(IBV_TAG, &v, sizeof(T));
|
||||||
|
sockets_[0].recv(IBV_TAG, result.data(), size_ * sizeof(T));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int rank_;
|
||||||
|
int size_;
|
||||||
|
std::vector<detail::TCPSocket> sockets_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Manages a set of connections. Among other things it uses a side channel to
|
||||||
|
* exchange the necessary information and then configure the connections to be
|
||||||
|
* ready for RDMA operations.
|
||||||
|
*/
|
||||||
|
class ConnectionManager {
|
||||||
|
public:
|
||||||
|
ConnectionManager(
|
||||||
|
int rank,
|
||||||
|
const std::vector<std::string>& device_names,
|
||||||
|
const char* coordinator_addr)
|
||||||
|
: rank_(rank),
|
||||||
|
size_(device_names.size()),
|
||||||
|
side_channel_(rank_, size_, coordinator_addr) {
|
||||||
|
create_contexts(device_names);
|
||||||
|
if (connections_[rank_].ctx != nullptr) {
|
||||||
|
throw std::runtime_error("[ibv] Malformed device file");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int rank() const {
|
||||||
|
return rank_;
|
||||||
|
}
|
||||||
|
|
||||||
|
int size() const {
|
||||||
|
return size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs the connection initialization. Namely, after this call all
|
||||||
|
* Connection objects should have a queue pair in RTS state.
|
||||||
|
*/
|
||||||
|
void initialize(int num_buffers, size_t num_bytes) {
|
||||||
|
// Create the queue pairs
|
||||||
|
for (auto& conn : connections_) {
|
||||||
|
if (conn.ctx == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
conn.allocate_protection_domain();
|
||||||
|
conn.create_completion_queue(MAX_SEND_WR + MAX_RECV_WR);
|
||||||
|
conn.create_queue_pair();
|
||||||
|
}
|
||||||
|
|
||||||
|
allocate_buffers(num_buffers, num_bytes);
|
||||||
|
|
||||||
|
// Gather the information to be exchanged
|
||||||
|
std::vector<Destination> info;
|
||||||
|
for (auto& conn : connections_) {
|
||||||
|
info.emplace_back(conn.info());
|
||||||
|
}
|
||||||
|
auto all_infos = side_channel_.all_gather(info);
|
||||||
|
|
||||||
|
// Transition queue pairs to RTS
|
||||||
|
for (int peer = 0; peer < size_; peer++) {
|
||||||
|
if (peer == rank_) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto peer_info = all_infos[peer][rank_];
|
||||||
|
connections_[peer].queue_pair_init();
|
||||||
|
connections_[peer].queue_pair_rtr(peer_info);
|
||||||
|
connections_[peer].queue_pair_rts();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void allocate_buffers(int num_buffers, size_t num_bytes) {
|
||||||
|
// Deregister any buffers and free the memory
|
||||||
|
buffers_.clear();
|
||||||
|
|
||||||
|
// Allocate the memory
|
||||||
|
for (int i = 0; i < num_buffers; i++) {
|
||||||
|
for (int j = 0; j < size_; j++) {
|
||||||
|
buffers_.emplace_back(num_bytes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < num_buffers; i++) {
|
||||||
|
for (int j = 0; j < size_; j++) {
|
||||||
|
// This is our send buffer so register it with all pds so we can send
|
||||||
|
// it to all connected devices.
|
||||||
|
if (j == rank_) {
|
||||||
|
for (auto& conn : connections_) {
|
||||||
|
if (conn.ctx != nullptr) {
|
||||||
|
buffers_[i * size_ + j].register_to_protection_domain(
|
||||||
|
conn.protection_domain);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is the recv buffer from rank j so register it to rank j's
|
||||||
|
// protection domain.
|
||||||
|
else {
|
||||||
|
buffers_[i * size_ + j].register_to_protection_domain(
|
||||||
|
connections_[j].protection_domain);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void send_to(int rank, int buff) {
|
||||||
|
connections_[rank].post_send(
|
||||||
|
buffers_[buff * size_ + rank_],
|
||||||
|
SEND_WR << 16 | buff << 8 | rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
void recv_from(int rank, int buff) {
|
||||||
|
connections_[rank].post_recv(
|
||||||
|
buffers_[buff * size_ + rank], RECV_WR << 16 | buff << 8 | rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Poll all connections and save the work completions and return the
|
||||||
|
* corresponding length.
|
||||||
|
*/
|
||||||
|
int poll(int num_completions, ibv_wc* work_completions) {
|
||||||
|
int completions = 0;
|
||||||
|
for (int r = 0; r < size_; r++) {
|
||||||
|
if (r == rank_) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (completions >= num_completions) {
|
||||||
|
return completions;
|
||||||
|
}
|
||||||
|
|
||||||
|
int c = ibv_poll_cq(
|
||||||
|
connections_[r].completion_queue,
|
||||||
|
num_completions - completions,
|
||||||
|
work_completions + completions);
|
||||||
|
|
||||||
|
completions += c;
|
||||||
|
}
|
||||||
|
return completions;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
int poll(int rank, int num_completions, ibv_wc* work_completions) {
|
||||||
|
return ibv_poll_cq(
|
||||||
|
connections_[rank].completion_queue,
|
||||||
|
num_completions,
|
||||||
|
work_completions);
|
||||||
|
}
|
||||||
|
|
||||||
|
SharedBuffer& send_buffer(int buff) {
|
||||||
|
return buffers_[buff * size_ + rank_];
|
||||||
|
}
|
||||||
|
|
||||||
|
SharedBuffer& buffer(int rank, int buff) {
|
||||||
|
return buffers_[buff * size_ + rank];
|
||||||
|
}
|
||||||
|
|
||||||
|
void barrier() {
|
||||||
|
side_channel_.all_gather<int>(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void create_contexts(const std::vector<std::string>& device_names) {
|
||||||
|
int num_devices = 0;
|
||||||
|
ibv_device** devices = ibv_get_device_list(&num_devices);
|
||||||
|
for (auto& name : device_names) {
|
||||||
|
// Empty so add a nullptr context
|
||||||
|
if (name.empty()) {
|
||||||
|
connections_.emplace_back(nullptr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search for the name and try to open the device
|
||||||
|
for (int i = 0; i < num_devices; i++) {
|
||||||
|
if (name == ibv_get_device_name(devices[i])) {
|
||||||
|
auto ctx = ibv_open_device(devices[i]);
|
||||||
|
if (ctx == nullptr) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[ibv] Could not open device " << name;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
connections_.emplace_back(ctx);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ibv_free_device_list(devices);
|
||||||
|
}
|
||||||
|
|
||||||
|
int rank_;
|
||||||
|
int size_;
|
||||||
|
SideChannel side_channel_;
|
||||||
|
std::vector<Connection> connections_;
|
||||||
|
std::vector<SharedBuffer> buffers_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct SumOp {
|
||||||
|
void operator()(const T* input, T* output, size_t N) const {
|
||||||
|
while (N-- > 0) {
|
||||||
|
*output += *input;
|
||||||
|
input++;
|
||||||
|
output++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<std::string> load_device_names(int rank, const char* dev_file) {
|
||||||
|
std::vector<std::string> device_names;
|
||||||
|
std::ifstream f(dev_file);
|
||||||
|
|
||||||
|
json devices = json::parse(f);
|
||||||
|
devices = devices[rank];
|
||||||
|
for (auto it = devices.begin(); it != devices.end(); it++) {
|
||||||
|
std::string n;
|
||||||
|
if (!it->is_null()) {
|
||||||
|
n = *it;
|
||||||
|
}
|
||||||
|
device_names.emplace_back(std::move(n));
|
||||||
|
}
|
||||||
|
|
||||||
|
return device_names;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::ibv {
|
||||||
|
|
||||||
|
class IBVGroup : public GroupImpl {
|
||||||
|
public:
|
||||||
|
IBVGroup(ConnectionManager cm) : cm_(std::move(cm)), rank_(cm.rank()), size_(cm.size()) {}
|
||||||
|
|
||||||
|
Stream communication_stream(StreamOrDevice s) override {
|
||||||
|
return to_stream(s, Device::cpu);
|
||||||
|
}
|
||||||
|
|
||||||
|
int rank() override {
|
||||||
|
return cm_.rank();
|
||||||
|
}
|
||||||
|
|
||||||
|
int size() override {
|
||||||
|
return cm_.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_sum(const array& input, array& output, Stream stream) override {
|
||||||
|
dispatch_all_types(output.dtype(), [&](auto type_tag) {
|
||||||
|
using T = MLX_GET_TYPE(type_tag);
|
||||||
|
all_reduce<T>(input, output, stream, SumOp<T>{});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_max(const array& input, array& output, Stream stream) override {
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_min(const array& input, array& output, Stream stream) override {
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_gather(const array& input, array& output, Stream stream) override {
|
||||||
|
}
|
||||||
|
|
||||||
|
void send(const array& input, int dst, Stream stream) override {
|
||||||
|
}
|
||||||
|
|
||||||
|
void recv(array& out, int src, Stream stream) override {
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||||
|
throw std::runtime_error("[ibv] Group split not supported.");
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void post_recv_all(int buffer) {
|
||||||
|
for (int i = 0; i < size_; i++) {
|
||||||
|
if (i == rank_) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
cm_.recv_from(i, buffer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void post_send_all(int buffer) {
|
||||||
|
for (int i = 0; i < size_; i++) {
|
||||||
|
if (i == rank_) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
cm_.send_to(i, buffer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename ReduceOp>
|
||||||
|
void all_reduce(
|
||||||
|
const array& input,
|
||||||
|
array& output,
|
||||||
|
Stream stream,
|
||||||
|
ReduceOp reduce_op) {
|
||||||
|
auto in_ptr = input.data<T>();
|
||||||
|
auto out_ptr = output.data<T>();
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_output_array(output);
|
||||||
|
encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() {
|
||||||
|
// If not inplace all reduce then copy the input to the output first
|
||||||
|
if (in_ptr != out_ptr) {
|
||||||
|
std::memcpy(out_ptr, in_ptr, size * sizeof(T));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fully connected all reduce
|
||||||
|
T* data = out_ptr;
|
||||||
|
constexpr int64_t N = BUFFER_SIZE / sizeof(T);
|
||||||
|
int64_t total = static_cast<int64_t>(size);
|
||||||
|
int64_t offset = N;
|
||||||
|
int a = 0, b = 1;
|
||||||
|
|
||||||
|
int mask_init = 1 << rank_;
|
||||||
|
int mask_target = (1 << size_) - 1;
|
||||||
|
|
||||||
|
// Handle the first piece of data.
|
||||||
|
post_recv_all(a);
|
||||||
|
std::copy(
|
||||||
|
data,
|
||||||
|
data + std::min(N, total),
|
||||||
|
cm_.send_buffer(a).begin<T>());
|
||||||
|
post_send_all(a);
|
||||||
|
|
||||||
|
int mask_a_send = mask_init;
|
||||||
|
int mask_a_recv = mask_init;
|
||||||
|
int mask_b_send = mask_init;
|
||||||
|
int mask_b_recv = mask_init;
|
||||||
|
|
||||||
|
while (offset < total) {
|
||||||
|
// While the previous chunk is in flight copy to the next send buffer
|
||||||
|
std::copy(
|
||||||
|
data + offset,
|
||||||
|
data + std::min(offset + N, total),
|
||||||
|
cm_.send_buffer(b).begin<T>());
|
||||||
|
|
||||||
|
// Send if the previous send is already done
|
||||||
|
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
|
||||||
|
if (i == rank_) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (mask_a_send & m) {
|
||||||
|
cm_.send_to(i, b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recv the next buffer if the previous one is already done
|
||||||
|
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
|
||||||
|
if (i == rank_) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (mask_a_recv & m) {
|
||||||
|
cm_.recv_from(i, b);
|
||||||
|
reduce_op(
|
||||||
|
cm_.buffer(i, a).begin<T>(),
|
||||||
|
data + std::max(offset - N, 0LL),
|
||||||
|
std::min(N, total - offset));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loop until this chunk is all done
|
||||||
|
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
|
||||||
|
ibv_wc wc[8];
|
||||||
|
int n = cm_.poll(8, wc);
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
int work_type = wc[i].wr_id >> 16;
|
||||||
|
int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||||
|
int rank = wc[i].wr_id & 0xff;
|
||||||
|
|
||||||
|
if (work_type == SEND_WR) {
|
||||||
|
if (buff == a) {
|
||||||
|
cm_.send_to(rank, b);
|
||||||
|
mask_a_send |= 1 << rank;
|
||||||
|
} else {
|
||||||
|
mask_b_send |= 1 << rank;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (buff == a) {
|
||||||
|
cm_.recv_from(rank, b);
|
||||||
|
mask_a_recv |= 1 << rank;
|
||||||
|
reduce_op(
|
||||||
|
cm_.buffer(rank, a).begin<T>(),
|
||||||
|
data + std::max(offset - N, 0LL),
|
||||||
|
std::min(N, total - offset));
|
||||||
|
} else {
|
||||||
|
mask_b_recv |= 1 << rank;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::swap(a, b);
|
||||||
|
mask_a_send = mask_b_send;
|
||||||
|
mask_a_recv = mask_b_recv;
|
||||||
|
mask_b_send = mask_init;
|
||||||
|
mask_b_recv = mask_init;
|
||||||
|
offset += N;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
|
||||||
|
ibv_wc wc[8];
|
||||||
|
int n = cm_.poll(8, wc);
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
int work_type = wc[i].wr_id >> 16;
|
||||||
|
int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||||
|
int rank = wc[i].wr_id & 0xff;
|
||||||
|
|
||||||
|
if (work_type == SEND_WR) {
|
||||||
|
mask_a_send |= 1 << rank;
|
||||||
|
} else {
|
||||||
|
mask_a_recv |= 1 << rank;
|
||||||
|
reduce_op(
|
||||||
|
cm_.buffer(rank, a).begin<T>(),
|
||||||
|
data + offset - N,
|
||||||
|
std::min(N, total + N - offset));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
ConnectionManager cm_;
|
||||||
|
int rank_;
|
||||||
|
int size_;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||||
|
const char* dev_file = std::getenv("MLX_IBV_DEVICES");
|
||||||
|
const char* coordinator = std::getenv("MLX_IBV_COORDINATOR");
|
||||||
|
const char* rank_str = std::getenv("MLX_RANK");
|
||||||
|
const char* ring_verbose = std::getenv("MLX_IBV_VERBOSE");
|
||||||
|
|
||||||
|
if (!dev_file || !coordinator || !rank_str) {
|
||||||
|
if (strict) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[ibv] You need to provide via environment variables a rank (MLX_RANK), "
|
||||||
|
<< "a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_IBV_COORDINATOR) "
|
||||||
|
<< "but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "")
|
||||||
|
<< "\", MLX_IBV_DEVICES=\"" << ((dev_file) ? dev_file : "")
|
||||||
|
<< "\" and MLX_IBV_COORDINATOR=\""
|
||||||
|
<< ((coordinator) ? coordinator : "");
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto rank = std::atoi(rank_str);
|
||||||
|
auto device_names = load_device_names(rank, dev_file);
|
||||||
|
|
||||||
|
auto cm = ConnectionManager(rank, device_names, coordinator);
|
||||||
|
if (cm.size() > MAX_PEERS) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[ibv] The maximum number of supported peers is "
|
||||||
|
<< MAX_PEERS << " but " << cm.size() << " was provided";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
cm.initialize(NUM_BUFFERS, BUFFER_SIZE);
|
||||||
|
cm.barrier();
|
||||||
|
|
||||||
|
return std::make_shared<IBVGroup>(std::move(cm));
|
||||||
|
|
||||||
|
//cm.recv_from(rank ^ 1, 0);
|
||||||
|
//cm.barrier();
|
||||||
|
//std::fill(
|
||||||
|
// cm.buffer(rank, 0).begin<int>(),
|
||||||
|
// cm.buffer(rank, 0).end<int>(),
|
||||||
|
// 2 * rank + 1);
|
||||||
|
//cm.barrier();
|
||||||
|
//cm.send_to(rank ^ 1, 0);
|
||||||
|
|
||||||
|
//ibv_wc wc[8];
|
||||||
|
//bool sent = false, recvd = false;
|
||||||
|
//while (!(sent && recvd)) {
|
||||||
|
// int num_completions = cm.poll(8, wc);
|
||||||
|
// for (int i = 0; i < num_completions; i++) {
|
||||||
|
// if (wc[i].status != IBV_WC_SUCCESS) {
|
||||||
|
// std::cout << "Error " << wc[i].status << std::endl;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// int work_type = wc[i].wr_id >> 16;
|
||||||
|
// int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||||
|
// int rank = wc[i].wr_id & 0xff;
|
||||||
|
|
||||||
|
// std::cout << work_type << " " << buff << " " << rank << std::endl;
|
||||||
|
// print_wc(wc[i]);
|
||||||
|
|
||||||
|
// sent |= (work_type == SEND_WR);
|
||||||
|
// recvd |= (work_type == RECV_WR);
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
|
//std::cout << rank << " ours " << *cm.buffer(rank, 0).data<int>() << std::endl;
|
||||||
|
//std::cout << rank << " theirs " << *cm.buffer(rank ^ 1, 0).data<int>()
|
||||||
|
// << std::endl;
|
||||||
|
|
||||||
|
//return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::ibv
|
||||||
12
mlx/distributed/ibv/ibv.h
Normal file
12
mlx/distributed/ibv/ibv.h
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/distributed/distributed.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::ibv {
|
||||||
|
|
||||||
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
|
||||||
|
bool is_available();
|
||||||
|
std::shared_ptr<GroupImpl> init(bool strict = false);
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::ibv
|
||||||
@@ -1,9 +1,6 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
#include <arpa/inet.h>
|
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
#include <netdb.h>
|
|
||||||
#include <netinet/in.h>
|
|
||||||
#include <netinet/tcp.h>
|
#include <netinet/tcp.h>
|
||||||
#include <sys/socket.h>
|
#include <sys/socket.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
@@ -22,6 +19,7 @@
|
|||||||
#include "mlx/backend/cpu/encoder.h"
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
|
#include "mlx/distributed/utils.h"
|
||||||
#include "mlx/threadpool.h"
|
#include "mlx/threadpool.h"
|
||||||
|
|
||||||
#ifndef SOL_TCP
|
#ifndef SOL_TCP
|
||||||
@@ -94,6 +92,7 @@ constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
|
|||||||
constexpr const size_t ALL_SUM_BUFFERS = 2;
|
constexpr const size_t ALL_SUM_BUFFERS = 2;
|
||||||
constexpr const int CONN_ATTEMPTS = 5;
|
constexpr const int CONN_ATTEMPTS = 5;
|
||||||
constexpr const int CONN_WAIT = 1000;
|
constexpr const int CONN_WAIT = 1000;
|
||||||
|
constexpr const char* RING_TAG = "[ring]";
|
||||||
|
|
||||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
@@ -296,55 +295,6 @@ class CommunicationThreads {
|
|||||||
std::unordered_map<int, SocketThread> threads_;
|
std::unordered_map<int, SocketThread> threads_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct address_t {
|
|
||||||
sockaddr_storage addr;
|
|
||||||
socklen_t len;
|
|
||||||
|
|
||||||
const sockaddr* get() const {
|
|
||||||
return (struct sockaddr*)&addr;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse a sockaddr from an ip and port provided as strings.
|
|
||||||
*/
|
|
||||||
address_t parse_address(const std::string& ip, const std::string& port) {
|
|
||||||
struct addrinfo hints, *res;
|
|
||||||
memset(&hints, 0, sizeof(hints));
|
|
||||||
hints.ai_family = AF_UNSPEC;
|
|
||||||
hints.ai_socktype = SOCK_STREAM;
|
|
||||||
|
|
||||||
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
|
||||||
if (status != 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Can't parse address " << ip << ":" << port;
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
address_t result;
|
|
||||||
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
|
||||||
result.len = res->ai_addrlen;
|
|
||||||
freeaddrinfo(res);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse a sockaddr provided as an <ip>:<port> string.
|
|
||||||
*/
|
|
||||||
address_t parse_address(const std::string& ip_port) {
|
|
||||||
auto colon = ip_port.find(":");
|
|
||||||
if (colon == std::string::npos) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Can't parse address " << ip_port;
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
std::string ip(ip_port.begin(), ip_port.begin() + colon);
|
|
||||||
std::string port(ip_port.begin() + colon + 1, ip_port.end());
|
|
||||||
|
|
||||||
return parse_address(ip, port);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load all addresses from the json hostfile. The hostfile is a list of
|
* Load all addresses from the json hostfile. The hostfile is a list of
|
||||||
* addresses in order of rank. For each rank there can be many addresses so
|
* addresses in order of rank. For each rank there can be many addresses so
|
||||||
@@ -357,15 +307,15 @@ address_t parse_address(const std::string& ip_port) {
|
|||||||
* ["ip3:5000", "ip3:5001"],
|
* ["ip3:5000", "ip3:5001"],
|
||||||
* ]
|
* ]
|
||||||
*/
|
*/
|
||||||
std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
|
std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
|
||||||
std::vector<std::vector<address_t>> nodes;
|
std::vector<std::vector<detail::address_t>> nodes;
|
||||||
std::ifstream f(hostfile);
|
std::ifstream f(hostfile);
|
||||||
|
|
||||||
json hosts = json::parse(f);
|
json hosts = json::parse(f);
|
||||||
for (auto& h : hosts) {
|
for (auto& h : hosts) {
|
||||||
std::vector<address_t> host;
|
std::vector<detail::address_t> host;
|
||||||
for (auto& ips : h) {
|
for (auto& ips : h) {
|
||||||
host.push_back(parse_address(ips.get<std::string>()));
|
host.push_back(std::move(detail::parse_address(ips.get<std::string>())));
|
||||||
}
|
}
|
||||||
nodes.push_back(std::move(host));
|
nodes.push_back(std::move(host));
|
||||||
}
|
}
|
||||||
@@ -377,73 +327,15 @@ std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
|
|||||||
* Create a socket and accept one connection for each of the provided
|
* Create a socket and accept one connection for each of the provided
|
||||||
* addresses.
|
* addresses.
|
||||||
*/
|
*/
|
||||||
std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
|
std::vector<int> accept_connections(
|
||||||
|
const std::vector<detail::address_t>& addresses) {
|
||||||
std::vector<int> sockets;
|
std::vector<int> sockets;
|
||||||
int success;
|
int success;
|
||||||
|
|
||||||
for (auto& address : addresses) {
|
for (auto& address : addresses) {
|
||||||
// Create the socket to wait for connections from the peers
|
detail::TCPSocket socket(RING_TAG);
|
||||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
socket.listen(RING_TAG, address);
|
||||||
if (sock < 0) {
|
sockets.push_back(socket.accept(RING_TAG));
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure we can launch immediately after shutdown by setting the
|
|
||||||
// reuseaddr option so that we don't get address already in use errors
|
|
||||||
int enable = 1;
|
|
||||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't enable reuseport (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bind the socket to the address and port
|
|
||||||
success = bind(sock, address.get(), address.len);
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't bind socket (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for connections
|
|
||||||
success = listen(sock, 0);
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't listen (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
int peer_socket = accept(sock, nullptr, nullptr);
|
|
||||||
if (peer_socket < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Accept failed (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the listening socket
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
|
|
||||||
sockets.push_back(peer_socket);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sockets;
|
return sockets;
|
||||||
@@ -454,55 +346,33 @@ std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
|
|||||||
* provided addresses.
|
* provided addresses.
|
||||||
*/
|
*/
|
||||||
std::vector<int> make_connections(
|
std::vector<int> make_connections(
|
||||||
const std::vector<address_t>& addresses,
|
const std::vector<detail::address_t>& addresses,
|
||||||
bool verbose) {
|
bool verbose) {
|
||||||
std::vector<int> sockets;
|
std::vector<int> sockets;
|
||||||
int success;
|
int success;
|
||||||
|
|
||||||
for (auto& address : addresses) {
|
for (auto& address : addresses) {
|
||||||
int sock;
|
sockets.push_back(detail::TCPSocket::connect(
|
||||||
|
RING_TAG,
|
||||||
// Attempt to connect to the peer CONN_ATTEMPTS times with exponential
|
address,
|
||||||
// backoff. TODO: Do we need that?
|
CONN_ATTEMPTS,
|
||||||
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
|
CONN_WAIT,
|
||||||
// Create the socket
|
[verbose](int attempt, int wait) {
|
||||||
sock = socket(AF_INET, SOCK_STREAM, 0);
|
log_info(
|
||||||
if (sock < 0) {
|
verbose,
|
||||||
std::ostringstream msg;
|
"Attempt",
|
||||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
attempt,
|
||||||
throw std::runtime_error(msg.str());
|
"waiting",
|
||||||
}
|
wait,
|
||||||
|
"ms (error:",
|
||||||
if (attempt > 0) {
|
errno,
|
||||||
int wait = (1 << (attempt - 1)) * CONN_WAIT;
|
")");
|
||||||
log_info(
|
}));
|
||||||
verbose,
|
|
||||||
"Attempt",
|
|
||||||
attempt,
|
|
||||||
"wait",
|
|
||||||
wait,
|
|
||||||
"ms (error:",
|
|
||||||
errno,
|
|
||||||
")");
|
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
|
||||||
}
|
|
||||||
|
|
||||||
success = connect(sock, address.get(), address.len);
|
|
||||||
if (success == 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (success < 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't connect (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
sockets.push_back(sock);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sockets;
|
return sockets;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct SumOp {
|
struct SumOp {
|
||||||
void operator()(const T* input, T* output, size_t N) {
|
void operator()(const T* input, T* output, size_t N) {
|
||||||
@@ -540,7 +410,10 @@ struct MinOp {
|
|||||||
|
|
||||||
class RingGroup : public GroupImpl {
|
class RingGroup : public GroupImpl {
|
||||||
public:
|
public:
|
||||||
RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
|
RingGroup(
|
||||||
|
int rank,
|
||||||
|
std::vector<std::vector<detail::address_t>> nodes,
|
||||||
|
bool verbose)
|
||||||
: rank_(rank), verbose_(verbose), pool_(0) {
|
: rank_(rank), verbose_(verbose), pool_(0) {
|
||||||
if (rank_ > 0 && rank_ >= nodes.size()) {
|
if (rank_ > 0 && rank_ >= nodes.size()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
|
|||||||
189
mlx/distributed/utils.cpp
Normal file
189
mlx/distributed/utils.cpp
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <netdb.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <sstream>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#include "mlx/distributed/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::detail {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr from an ip and port provided as strings.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip, const std::string& port) {
|
||||||
|
struct addrinfo hints, *res;
|
||||||
|
memset(&hints, 0, sizeof(hints));
|
||||||
|
hints.ai_family = AF_UNSPEC;
|
||||||
|
hints.ai_socktype = SOCK_STREAM;
|
||||||
|
|
||||||
|
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
||||||
|
if (status != 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Can't parse address " << ip << ":" << port;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
address_t result;
|
||||||
|
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
||||||
|
result.len = res->ai_addrlen;
|
||||||
|
freeaddrinfo(res);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip_port) {
|
||||||
|
auto colon = ip_port.find(":");
|
||||||
|
if (colon == std::string::npos) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Can't parse address " << ip_port;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
std::string ip(ip_port.begin(), ip_port.begin() + colon);
|
||||||
|
std::string port(ip_port.begin() + colon + 1, ip_port.end());
|
||||||
|
|
||||||
|
return parse_address(ip, port);
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket::TCPSocket(const char* tag) {
|
||||||
|
sock_ = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
if (sock_ < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't create socket (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket::TCPSocket(TCPSocket&& s) {
|
||||||
|
sock_ = s.sock_;
|
||||||
|
s.sock_ = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket::TCPSocket(int s) : sock_(s) {}
|
||||||
|
|
||||||
|
TCPSocket::~TCPSocket() {
|
||||||
|
if (sock_ > 0) {
|
||||||
|
shutdown(sock_, 2);
|
||||||
|
close(sock_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TCPSocket::listen(const char* tag, const address_t& addr) {
|
||||||
|
int success;
|
||||||
|
|
||||||
|
// Make sure we can launch immediately after shutdown by setting the
|
||||||
|
// reuseaddr option so that we don't get address already in use errors
|
||||||
|
int enable = 1;
|
||||||
|
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||||
|
if (success < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't enable reuseaddr (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
||||||
|
if (success < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't enable reuseport (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bind the socket to the address and port
|
||||||
|
success = bind(sock_, addr.get(), addr.len);
|
||||||
|
if (success < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't bind socket (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare waiting for connections
|
||||||
|
success = ::listen(sock_, 0);
|
||||||
|
if (success < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't listen (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket TCPSocket::accept(const char* tag) {
|
||||||
|
int peer = ::accept(sock_, nullptr, nullptr);
|
||||||
|
if (peer < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Accept failed (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
return TCPSocket(peer);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TCPSocket::send(const char* tag, const void* data, size_t len) {
|
||||||
|
while (len > 0) {
|
||||||
|
auto n = ::send(sock_, data, len, 0);
|
||||||
|
if (n <= 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Send failed with errno=" << errno;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
len -= n;
|
||||||
|
data = static_cast<const char*>(data) + n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TCPSocket::recv(const char* tag, void* data, size_t len) {
|
||||||
|
while (len > 0) {
|
||||||
|
auto n = ::recv(sock_, data, len, 0);
|
||||||
|
if (n <= 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Recv failed with errno=" << errno;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
len -= n;
|
||||||
|
data = static_cast<char*>(data) + n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket TCPSocket::connect(
|
||||||
|
const char* tag,
|
||||||
|
const address_t& addr,
|
||||||
|
int num_retries,
|
||||||
|
int wait,
|
||||||
|
std::function<void(int, int)> cb) {
|
||||||
|
int sock, success;
|
||||||
|
|
||||||
|
// Attempt to connect `num_retries` times with exponential backoff.
|
||||||
|
for (int attempt = 0; attempt < num_retries; attempt++) {
|
||||||
|
// Create the socket
|
||||||
|
sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
if (sock < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't create socket to connect (error: " << errno
|
||||||
|
<< ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
success = ::connect(sock, addr.get(), addr.len);
|
||||||
|
if (success == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
cb(attempt, wait);
|
||||||
|
if (wait > 0) {
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
||||||
|
}
|
||||||
|
|
||||||
|
wait <<= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (success < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't connect (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
return TCPSocket(sock);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::detail
|
||||||
62
mlx/distributed/utils.h
Normal file
62
mlx/distributed/utils.h
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <sys/socket.h>
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::detail {
|
||||||
|
|
||||||
|
struct address_t {
|
||||||
|
sockaddr_storage addr;
|
||||||
|
socklen_t len;
|
||||||
|
|
||||||
|
const sockaddr* get() const {
|
||||||
|
return (struct sockaddr*)&addr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr from an ip and port provided as strings.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip, const std::string& port);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip_port);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Small wrapper over a TCP socket to simplify initiating connections.
|
||||||
|
*/
|
||||||
|
class TCPSocket {
|
||||||
|
public:
|
||||||
|
TCPSocket(const char* tag);
|
||||||
|
TCPSocket(const TCPSocket&) = delete;
|
||||||
|
TCPSocket& operator=(const TCPSocket&) = delete;
|
||||||
|
TCPSocket(TCPSocket&& s);
|
||||||
|
~TCPSocket();
|
||||||
|
|
||||||
|
void listen(const char* tag, const address_t& addr);
|
||||||
|
TCPSocket accept(const char* tag);
|
||||||
|
|
||||||
|
void send(const char* tag, const void* data, size_t len);
|
||||||
|
void recv(const char* tag, void* data, size_t len);
|
||||||
|
|
||||||
|
operator int() const {
|
||||||
|
return sock_;
|
||||||
|
}
|
||||||
|
|
||||||
|
static TCPSocket connect(
|
||||||
|
const char* tag,
|
||||||
|
const address_t& addr,
|
||||||
|
int num_retries = 1,
|
||||||
|
int wait = 0,
|
||||||
|
std::function<void(int, int)> cb = nullptr);
|
||||||
|
|
||||||
|
private:
|
||||||
|
TCPSocket(int sock);
|
||||||
|
|
||||||
|
int sock_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::detail
|
||||||
Reference in New Issue
Block a user