From a1212b4e44515af9eee049823a50e4d189914fd5 Mon Sep 17 00:00:00 2001
From: Ronan Collobert
Date: Thu, 30 Oct 2025 16:25:11 -0700
Subject: [PATCH] WIP (distributed)

---
 mlx/distributed/primitives.cpp |   8 +--
 mlx/distributed/ring/ring.cpp  | 121 +++++++++++++++++----------------
 2 files changed, 67 insertions(+), 62 deletions(-)

diff --git a/mlx/distributed/primitives.cpp b/mlx/distributed/primitives.cpp
index 5e8d5327a..0c87172be 100644
--- a/mlx/distributed/primitives.cpp
+++ b/mlx/distributed/primitives.cpp
@@ -27,7 +27,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
 }
 
 std::vector<array> AllReduce::jvp(
-    const std::vector<array>& primals,
+    const std::vector<array>& /* primals */,
     const std::vector<array>& tangents,
     const std::vector<int>&) {
   switch (reduce_type_) {
@@ -44,10 +44,10 @@ std::vector<array> AllReduce::jvp(
 }
 
 std::vector<array> AllReduce::vjp(
-    const std::vector<array>& primals,
+    const std::vector<array>& /* primals */,
     const std::vector<array>& cotangents,
     const std::vector<int>&,
-    const std::vector<array>& outputs) {
+    const std::vector<array>& /* outputs */) {
   return cotangents;
 }
 
@@ -58,7 +58,7 @@ std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
 }
 
 std::vector<array> AllGather::jvp(
-    const std::vector<array>& primals,
+    const std::vector<array>& /* primals */,
     const std::vector<array>& tangents,
     const std::vector<int>&) {
   return {all_gather(tangents[0], group(), stream())};
diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp
index ac55ea30b..702cf7a4c 100644
--- a/mlx/distributed/ring/ring.cpp
+++ b/mlx/distributed/ring/ring.cpp
@@ -90,8 +90,8 @@ namespace mlx::core::distributed::ring {
 
-constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
-constexpr const size_t ALL_SUM_BUFFERS = 2;
+constexpr const int64_t ALL_SUM_SIZE = 8 * 1024 * 1024;
+constexpr const int64_t ALL_SUM_BUFFERS = 2;
 constexpr const int CONN_ATTEMPTS = 5;
 constexpr const int CONN_WAIT = 1000;
 
@@ -141,27 +141,27 @@ class SocketThread {
   }
 
   template <typename T>
-  std::future<void> send(const T* buffer, size_t size) {
+  std::future<void> send(const T* buffer, int64_t size) {
     return send_impl(reinterpret_cast<const char*>(buffer), size * sizeof(T));
   }
 
   template <typename T>
-  std::future<void> recv(T* buffer, size_t size) {
+  std::future<void> recv(T* buffer, int64_t size) {
     return recv_impl(reinterpret_cast<char*>(buffer), size * sizeof(T));
   }
 
  private:
   struct SocketTask {
-    SocketTask(void* b, size_t s, std::promise<void>&& p)
+    SocketTask(void* b, int64_t s, std::promise<void>&& p)
         : buffer(b), size(s), promise(std::move(p)) {}
     SocketTask(SocketTask&& t)
         : buffer(t.buffer), size(t.size), promise(std::move(t.promise)) {}
     void* buffer;
-    size_t size;
+    int64_t size;
     std::promise<void> promise;
   };
 
-  std::future<void> send_impl(const char* buffer, size_t size) {
+  std::future<void> send_impl(const char* buffer, int64_t size) {
     std::promise<void> send_completed_promise;
     auto send_completed_future = send_completed_promise.get_future();
     if (size == 0) {
@@ -178,7 +178,7 @@ class SocketThread {
     return send_completed_future;
   }
 
-  std::future<void> recv_impl(char* buffer, size_t size) {
+  std::future<void> recv_impl(char* buffer, int64_t size) {
     std::promise<void> recv_completed_promise;
     auto recv_completed_future = recv_completed_promise.get_future();
     if (size == 0) {
@@ -232,7 +232,7 @@ class SocketThread {
 
     if (!recvs_.empty()) {
       auto& task = recvs_.front();
-      ssize_t r = ::recv(fd_, task.buffer, task.size, 0);
+      int64_t r = ::recv(fd_, task.buffer, task.size, 0);
       if (r > 0) {
         task.buffer = static_cast<char*>(task.buffer) + r;
         task.size -= r;
@@ -246,7 +246,7 @@ class SocketThread {
     }
     if (!sends_.empty()) {
       auto& task = sends_.front();
-      ssize_t r = ::send(fd_, task.buffer, task.size, 0);
+      int64_t r = ::send(fd_, task.buffer, task.size, 0);
       if (r > 0) {
         task.buffer = static_cast<char*>(task.buffer) + r;
         task.size -= r;
@@ -283,12 +283,12 @@ class CommunicationThreads {
   }
 
   template <typename T>
-  std::future<void> send(int socket, T* buffer, size_t size) {
+  std::future<void> send(int socket, T* buffer, int64_t size) {
     return threads_.at(socket).send(buffer, size);
   }
 
   template <typename T>
-  std::future<void> recv(int socket, T* buffer, size_t size) {
+  std::future<void> recv(int socket, T* buffer, int64_t size) {
     return threads_.at(socket).recv(buffer, size);
   }
 
@@ -505,7 +505,7 @@ std::vector<int> make_connections(
 }
 
 template <typename T>
 struct SumOp {
-  void operator()(const T* input, T* output, size_t N) {
+  void operator()(const T* input, T* output, int64_t N) {
     while (N-- > 0) {
       *output += *input;
       input++;
@@ -516,7 +516,7 @@ struct SumOp {
 
 template <typename T>
 struct MaxOp {
-  void operator()(const T* input, T* output, size_t N) {
+  void operator()(const T* input, T* output, int64_t N) {
     while (N-- > 0) {
       *output = std::max(*output, *input);
       input++;
@@ -527,7 +527,7 @@ struct MaxOp {
 
 template <typename T>
 struct MinOp {
-  void operator()(const T* input, T* output, size_t N) {
+  void operator()(const T* input, T* output, int64_t N) {
     while (N-- > 0) {
       *output = std::min(*output, *input);
       input++;
@@ -542,7 +542,7 @@ class RingGroup : public GroupImpl {
  public:
   RingGroup(int rank, std::vector<std::vector<std::string>> nodes, bool verbose)
       : rank_(rank), verbose_(verbose), pool_(0) {
-    if (rank_ > 0 && rank_ >= nodes.size()) {
+    if (rank_ > 0 && rank_ >= std::ssize(nodes)) {
       throw std::runtime_error(
           "[ring] Rank cannot be larger than the size of the group");
     }
@@ -589,7 +589,7 @@ class RingGroup : public GroupImpl {
 
     // Configure all sockets to use TCP no delay.
     int one = 1;
-    for (int i = 0; i < sockets_right_.size(); i++) {
+    for (int64_t i = 0; i < std::ssize(sockets_right_); i++) {
       setsockopt(sockets_right_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
       setsockopt(sockets_left_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
     }
@@ -646,7 +646,8 @@ class RingGroup : public GroupImpl {
         output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
   }
 
-  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
+  std::shared_ptr<GroupImpl> split(int /* color */, int /* key */ = -1)
+      override {
     throw std::runtime_error("[ring] Group split not supported.");
   }
 
@@ -658,15 +659,15 @@ class RingGroup : public GroupImpl {
                       nbytes = input.nbytes(),
                       output_ptr = output.data<char>(),
                       this]() {
-      constexpr size_t min_send_size = 262144;
-      size_t n_gathers = std::max(
-          std::min(
+      constexpr int64_t min_send_size = 262144;
+      int64_t n_gathers = std::max<int64_t>(
+          std::min<int64_t>(
               sockets_right_.size() + sockets_left_.size(),
               nbytes / min_send_size),
-          size_t(1));
-      size_t bytes_per_gather = ceildiv(nbytes, n_gathers);
+          1);
+      int64_t bytes_per_gather = ceildiv(nbytes, n_gathers);
       std::vector<std::future<void>> all_gathers;
-      for (int i = 0; i < n_gathers; i++) {
+      for (int64_t i = 0; i < n_gathers; i++) {
         auto offset = i * bytes_per_gather;
         all_gathers.emplace_back(pool_.enqueue(std::bind(
             &RingGroup::all_gather_impl,
@@ -742,10 +743,14 @@ class RingGroup : public GroupImpl {
     auto out_ptr = output.data<T>();
     auto& encoder = cpu::get_command_encoder(stream);
     encoder.set_output_array(output);
-    encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() {
+    encoder.dispatch([in_ptr,
+                      out_ptr,
+                      size = static_cast<int64_t>(input.size()),
+                      this,
+                      reduce_op]() {
       // If the input data cannot be split into size_ segments then copy it and
      // all reduce a local buffer prefilled with 0s.
-      size_t nbytes = size * sizeof(T);
+      int64_t nbytes = size * sizeof(T);
       if (size < size_) {
         // TODO: Maybe allocate dynamically so we don't have the constraint
         // below?
@@ -778,16 +783,16 @@ class RingGroup : public GroupImpl {
 
     // Split the all reduces so that each member has at least 1 buffer to
     // send/recv per segment.
-    constexpr size_t min_send_size = 262144;
-    size_t n_reduces = std::max(
-        std::min(
+    constexpr int64_t min_send_size = 262144;
+    int64_t n_reduces = std::max<int64_t>(
+        std::min<int64_t>(
             sockets_right_.size() + sockets_left_.size(),
             nbytes / (size_ * min_send_size)),
-        size_t(1));
-    size_t step = ceildiv(size, n_reduces);
+        1);
+    int64_t step = ceildiv(size, n_reduces);
     std::vector<std::future<void>> all_sums;
-    for (int i = 0; i < n_reduces; i++) {
+    for (int64_t i = 0; i < n_reduces; i++) {
       all_sums.emplace_back(pool_.enqueue(std::bind(
           &RingGroup::all_reduce_impl,
           this,
@@ -810,7 +815,7 @@ class RingGroup : public GroupImpl {
   void all_reduce_impl(
       T* buffer,
       T* data,
-      size_t data_size,
+      int64_t data_size,
      int socket_right,
      int socket_left,
      int direction,
@@ -821,10 +826,10 @@ class RingGroup : public GroupImpl {
 
     // We split the data into `size_` segments of size `segment_size` and each
     // of these in smaller segments of ALL_SUM_SIZE which we 'll call packets.
-    size_t segment_size = ceildiv(data_size, size_);
-    size_t BUFFER_SIZE = std::max(
-        size_t(32768), std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2));
-    size_t n_packets = ceildiv(segment_size, BUFFER_SIZE);
+    int64_t segment_size = ceildiv(data_size, size_);
+    int64_t BUFFER_SIZE = std::max<int64_t>(
+        32768, std::min<int64_t>(ALL_SUM_SIZE / sizeof(T), segment_size / 2));
+    int64_t n_packets = ceildiv(segment_size, BUFFER_SIZE);
 
     // Initial segments
     int send_segment = rank_;
@@ -833,21 +838,21 @@ class RingGroup : public GroupImpl {
     // Plan the whole reduce in terms of sends and recvs as indices in data.
     // It makes the actual async send and recv a bit simpler to follow when
     // there are less offset calculations around.
-    std::vector<std::pair<size_t, size_t>> send_plan;
-    std::vector<std::pair<size_t, size_t>> recv_plan;
+    std::vector<std::pair<int64_t, int64_t>> send_plan;
+    std::vector<std::pair<int64_t, int64_t>> recv_plan;
 
     // Two times the same send/recv operations, first scatter reduce and then
     // gather.
     for (int k = 0; k < 2; k++) {
       for (int i = 0; i < size_ - 1; i++) {
-        size_t send_start = send_segment * segment_size;
-        size_t send_stop =
+        int64_t send_start = send_segment * segment_size;
+        int64_t send_stop =
             std::min((send_segment + 1) * segment_size, data_size);
-        size_t recv_start = recv_segment * segment_size;
-        size_t recv_stop =
+        int64_t recv_start = recv_segment * segment_size;
+        int64_t recv_stop =
             std::min((recv_segment + 1) * segment_size, data_size);
 
-        for (size_t j = 0; j < n_packets; j++) {
+        for (int64_t j = 0; j < n_packets; j++) {
           send_plan.emplace_back(
               std::min(send_start + j * BUFFER_SIZE, send_stop),
               std::min(send_start + (j + 1) * BUFFER_SIZE, send_stop));
@@ -864,18 +869,18 @@ class RingGroup : public GroupImpl {
     // Running the plan is fairly simple, we keep a send and a recv in flight
     // while doing the summation.
     T* recv_buffers[ALL_SUM_BUFFERS];
-    for (int i = 0; i < ALL_SUM_BUFFERS; i++) {
+    for (int64_t i = 0; i < ALL_SUM_BUFFERS; i++) {
       recv_buffers[i] = buffer + i * BUFFER_SIZE;
     }
 
     std::future<void> sends[2], recvs[2];
     int a = 0;
     int b = (n_packets > 1) ? 1 : 0;
-    for (int i = 0, j = -b; i < send_plan.size(); j++, i++) {
+    for (int i = 0, j = -b; i < std::ssize(send_plan); j++, i++) {
       sends[a] = comm_.send(
           socket_send,
           data + send_plan[i].first,
           send_plan[i].second - send_plan[i].first);
-      if (2 * i < send_plan.size()) {
+      if (2 * i < std::ssize(send_plan)) {
         recvs[a] = comm_.recv(
             socket_recv,
             recv_buffers[i % ALL_SUM_BUFFERS],
@@ -890,7 +895,7 @@ class RingGroup : public GroupImpl {
       if (j >= 0) {
         sends[b].wait();
         recvs[b].wait();
-        if (2 * j < send_plan.size()) {
+        if (2 * j < std::ssize(send_plan)) {
          reduce_op(
              recv_buffers[j % ALL_SUM_BUFFERS],
              data + recv_plan[j].first,
@@ -907,8 +912,8 @@ class RingGroup : public GroupImpl {
   void all_gather_impl(
       const char* input,
       char* output,
-      size_t input_size,
-      size_t data_size,
+      int64_t input_size,
+      int64_t data_size,
       int socket_right,
       int socket_left,
       int direction) {
@@ -941,11 +946,11 @@ class RingGroup : public GroupImpl {
   }
 
   void
-  send(const std::vector<int>& sockets, const char* data, size_t data_size) {
-    size_t segment_size =
-        std::max(size_t(1024), ceildiv(data_size, sockets.size()));
+  send(const std::vector<int>& sockets, const char* data, int64_t data_size) {
+    int64_t segment_size =
+        std::max<int64_t>(1024, ceildiv(data_size, std::ssize(sockets)));
     std::vector<std::future<void>> sends;
-    for (int i = 0; i < sockets.size(); i++) {
+    for (int i = 0; i < std::ssize(sockets); i++) {
       if (i * segment_size >= data_size) {
         break;
       }
@@ -959,11 +964,11 @@ class RingGroup : public GroupImpl {
     }
   }
 
-  void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
-    size_t segment_size =
-        std::max(size_t(1024), ceildiv(data_size, sockets.size()));
+  void recv(const std::vector<int>& sockets, char* data, int64_t data_size) {
+    int64_t segment_size =
+        std::max<int64_t>(1024, ceildiv(data_size, std::ssize(sockets)));
     std::vector<std::future<void>> recvs;
-    for (int i = 0; i < sockets.size(); i++) {
+    for (int i = 0; i < std::ssize(sockets); i++) {
       if (i * segment_size >= data_size) {
         break;
      }
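
Review note: the hunks above keep all_reduce_impl's planning comments intact
("Plan the whole reduce in terms of sends and recvs as indices in data",
"first scatter reduce and then gather"). Below is a minimal standalone sketch
that replays only that plan construction with the patch's new int64_t types:
two passes around the ring, each segment cut into n_packets packets of at most
BUFFER_SIZE elements. The segment-advance rule
`(segment + size_ + direction) % size_` is an assumption here, since that code
falls outside the hunks shown, and `ceildiv` is redefined locally so the
sketch compiles on its own; everything else mirrors the patch.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Assumed helper; the real one lives elsewhere in ring.cpp.
int64_t ceildiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int size_ = 4; // number of ranks in the ring
  const int rank_ = 1;
  const int direction = -1; // pass segments toward the left neighbor
  const int64_t data_size = 10; // elements to all-reduce
  const int64_t BUFFER_SIZE = 3; // packet size, in elements

  // Each rank owns one segment; segments are cut into n_packets packets.
  int64_t segment_size = ceildiv(data_size, (int64_t)size_);
  int64_t n_packets = ceildiv(segment_size, BUFFER_SIZE);

  int send_segment = rank_;
  int recv_segment = (rank_ + size_ + direction) % size_;

  std::vector<std::pair<int64_t, int64_t>> send_plan;
  std::vector<std::pair<int64_t, int64_t>> recv_plan;

  // Two times the same send/recv operations: scatter reduce, then gather.
  for (int k = 0; k < 2; k++) {
    for (int i = 0; i < size_ - 1; i++) {
      int64_t send_start = send_segment * segment_size;
      int64_t send_stop =
          std::min((send_segment + 1) * segment_size, data_size);
      int64_t recv_start = recv_segment * segment_size;
      int64_t recv_stop =
          std::min((recv_segment + 1) * segment_size, data_size);
      for (int64_t j = 0; j < n_packets; j++) {
        send_plan.emplace_back(
            std::min(send_start + j * BUFFER_SIZE, send_stop),
            std::min(send_start + (j + 1) * BUFFER_SIZE, send_stop));
        recv_plan.emplace_back(
            std::min(recv_start + j * BUFFER_SIZE, recv_stop),
            std::min(recv_start + (j + 1) * BUFFER_SIZE, recv_stop));
      }
      // Assumed advance: next round forwards the segment just received.
      send_segment = (send_segment + size_ + direction) % size_;
      recv_segment = (recv_segment + size_ + direction) % size_;
    }
  }

  for (std::size_t i = 0; i < send_plan.size(); i++) {
    std::printf(
        "step %zu: send [%lld, %lld)  recv [%lld, %lld)\n",
        i,
        (long long)send_plan[i].first,
        (long long)send_plan[i].second,
        (long long)recv_plan[i].first,
        (long long)recv_plan[i].second);
  }
  return 0;
}

With these toy parameters each plan gets 2 * (size_ - 1) * n_packets = 6
entries, the same count the patch's loops produce, and every [start, stop)
interval is clipped to its segment and to data_size. The size_t -> int64_t
conversion in the patch also makes comparisons like
`2 * i < std::ssize(send_plan)` and `rank_ >= std::ssize(nodes)` fully signed,
where the old `send_plan.size()` / `nodes.size()` forms silently converted the
signed left-hand side to unsigned.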