Ring update (#1885)

Angelos Katharopoulos 2025-02-20 14:32:31 -08:00 committed by GitHub
parent 0ebc8a3d25
commit 10b271d963
2 changed files with 418 additions and 338 deletions


@@ -1,15 +1,19 @@
// Copyright © 2024 Apple Inc.
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#include <chrono>
#include <fstream>
#include <future>
#include <iostream>
#include <list>
#include <sstream>
#include <thread>
#include <unordered_map>
#include <json.hpp>
@@ -80,53 +84,17 @@
namespace mlx::core::distributed::ring {
constexpr const size_t PACKET_SIZE = 262144;
constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
constexpr const size_t ALL_SUM_BUFFERS = 2;
constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
using json = nlohmann::json;
using namespace std::chrono_literals;
namespace {
class Barrier {
public:
explicit Barrier(int n_threads)
: n_threads_(n_threads), count_(0), flag_(false) {}
void arrive_and_wait() {
std::unique_lock<std::mutex> lock(mtx_);
// Keep the flag that marks the current use of the barrier. The next use is
// going to have this flag flipped.
bool initial_flag = flag_;
// Increment the count
count_++;
// We are the last thread to arrive so reset the count, change the flag and
// notify everybody.
if (count_ == n_threads_) {
count_ = 0;
flag_ = !flag_;
cv_.notify_all();
}
// Wait for the rest to arrive
else {
cv_.wait(lock, [this, initial_flag]() { return initial_flag != flag_; });
}
}
private:
std::mutex mtx_;
std::condition_variable cv_;
int n_threads_;
int count_;
bool flag_; // we need this for sequential use of the barrier
};
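The Barrier removed above is a reusable two-phase barrier: count_ tracks arrivals and the flag_ toggle separates consecutive uses, so a fast thread cannot slip into the next round before the slow ones have been released. For orientation only, here is a minimal sketch of the same rendezvous written with C++20 std::barrier; this is an illustration, not code from the commit.

  // Sketch only: two threads meeting repeatedly at a reusable barrier,
  // the role Barrier::arrive_and_wait() played in the old ring code.
  // Requires C++20 for <barrier>.
  #include <barrier>
  #include <cstdio>
  #include <thread>

  int main() {
    std::barrier<> sync_point(2);
    auto work = [&](int id) {
      for (int round = 0; round < 3; round++) {
        std::printf("thread %d reached round %d\n", id, round);
        sync_point.arrive_and_wait(); // reusable: safe to hit again next round
      }
    };
    std::thread t(work, 1);
    work(0);
    t.join();
    return 0;
  }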
template <typename T>
void log(std::ostream& os, T first) {
os << first << std::endl;
@@ -151,6 +119,169 @@ decltype(T() * U()) ceildiv(T a, U b) {
return (a + b - 1) / b;
}
class SocketThread {
public:
SocketThread(int fd) : fd_(fd), stop_(false) {
worker_ = std::thread(&SocketThread::worker, this);
int flags = fcntl(fd, F_GETFL, 0);
fcntl(fd, F_SETFL, flags | O_NONBLOCK);
}
~SocketThread() {
stop_ = true;
condition_.notify_all();
worker_.join();
int flags = fcntl(fd_, F_GETFL, 0);
fcntl(fd_, F_SETFL, flags & ~O_NONBLOCK);
}
template <typename T>
std::future<void> send(T* buffer, size_t size) {
return send_impl(reinterpret_cast<char*>(buffer), size * sizeof(T));
}
template <typename T>
std::future<void> recv(T* buffer, size_t size) {
return recv_impl(reinterpret_cast<char*>(buffer), size * sizeof(T));
}
private:
struct SocketTask {
SocketTask(void* b, size_t s, std::promise<void>&& p)
: buffer(b), size(s), promise(std::move(p)) {}
SocketTask(SocketTask&& t)
: buffer(t.buffer), size(t.size), promise(std::move(t.promise)) {}
void* buffer;
size_t size;
std::promise<void> promise;
};
std::future<void> send_impl(char* buffer, size_t size) {
std::promise<void> send_completed_promise;
auto send_completed_future = send_completed_promise.get_future();
if (size == 0) {
send_completed_promise.set_value();
return send_completed_future;
}
{
std::unique_lock lock(queue_mutex_);
sends_.emplace_back(
SocketTask(buffer, size, std::move(send_completed_promise)));
}
condition_.notify_one();
return send_completed_future;
}
std::future<void> recv_impl(char* buffer, size_t size) {
std::promise<void> recv_completed_promise;
auto recv_completed_future = recv_completed_promise.get_future();
if (size == 0) {
recv_completed_promise.set_value();
return recv_completed_future;
}
{
std::unique_lock lock(queue_mutex_);
recvs_.emplace_back(
SocketTask(buffer, size, std::move(recv_completed_promise)));
}
condition_.notify_one();
return recv_completed_future;
}
bool have_tasks() {
return !(sends_.empty() && recvs_.empty());
}
void worker() {
bool delete_recv = false;
bool delete_send = false;
while (true) {
{
std::unique_lock lock(queue_mutex_);
if (delete_recv) {
recvs_.front().promise.set_value();
recvs_.pop_front();
delete_recv = false;
}
if (delete_send) {
sends_.front().promise.set_value();
sends_.pop_front();
delete_send = false;
}
if (stop_) {
return;
}
if (!have_tasks()) {
condition_.wait(lock, [this] { return stop_ || have_tasks(); });
if (stop_) {
return;
}
}
}
if (!recvs_.empty()) {
auto& task = recvs_.front();
ssize_t r = ::recv(fd_, task.buffer, task.size, 0);
if (r >= 0) {
task.buffer = static_cast<char*>(task.buffer) + r;
task.size -= r;
delete_recv = task.size == 0;
} else if (errno != EAGAIN) {
log_info(
true, "Receiving from socket", fd_, "failed with errno", errno);
return;
}
}
if (!sends_.empty()) {
auto& task = sends_.front();
ssize_t r = ::send(fd_, task.buffer, task.size, 0);
if (r >= 0) {
task.buffer = static_cast<char*>(task.buffer) + r;
task.size -= r;
delete_send = task.size == 0;
} else if (errno != EAGAIN) {
log_info(true, "Sending to socket", fd_, "failed with errno", errno);
return;
}
}
}
}
int fd_;
bool stop_;
std::thread worker_;
std::mutex queue_mutex_;
std::condition_variable condition_;
std::list<SocketTask> sends_;
std::list<SocketTask> recvs_;
};
class CommunicationThreads {
public:
void add(const std::vector<int>& sockets) {
for (int sock : sockets) {
threads_.emplace(sock, sock);
}
}
template <typename T>
std::future<void> send(int socket, T* buffer, size_t size) {
return threads_.at(socket).send<T>(buffer, size);
}
template <typename T>
std::future<void> recv(int socket, T* buffer, size_t size) {
return threads_.at(socket).recv<T>(buffer, size);
}
private:
std::unordered_map<int, SocketThread> threads_;
};
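SocketThread owns one socket (switched to non-blocking in its constructor) plus a worker thread that drains FIFO queues of send and recv tasks, completing each task's promise once the full byte count has been transferred; CommunicationThreads is simply a map from file descriptor to such a worker. A hedged usage sketch follows, assuming the two classes above are in scope and using a local socketpair purely as a stand-in for a ring connection.

  // Sketch only: exercise CommunicationThreads over a Unix socketpair.
  // fds[0] and fds[1] are hypothetical local endpoints standing in for a
  // ring's right/left connections.
  #include <sys/socket.h>
  #include <cassert>
  #include <vector>

  void socket_thread_demo() {
    int fds[2];
    if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) != 0) {
      return;
    }

    CommunicationThreads comm;
    comm.add({fds[0], fds[1]});

    std::vector<float> out(1024, 1.0f);
    std::vector<float> in(1024, 0.0f);

    // Each future resolves once the per-socket worker thread has moved
    // every byte of the request.
    auto sent = comm.send(fds[0], out.data(), out.size());
    auto recvd = comm.recv(fds[1], in.data(), in.size());
    sent.wait();
    recvd.wait();
    assert(in.front() == 1.0f && in.back() == 1.0f);
  }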
struct address_t {
sockaddr_storage addr;
socklen_t len;
@@ -378,140 +509,6 @@ void sum_inplace(const T* input, T* output, size_t N) {
}
}
template <typename T>
void _send(int sock, T* data, size_t start, size_t stop) {
if (stop <= start) {
return;
}
data += start;
size_t len = (stop - start) * sizeof(T);
const char* buffer = (const char*)data;
while (len > 0) {
ssize_t r = send(sock, buffer, len, 0);
if (r <= 0) {
std::ostringstream msg;
msg << "Send of " << len << " bytes failed (errno: " << errno << ")";
throw std::runtime_error(msg.str());
}
buffer += r;
len -= r;
}
}
template <typename T>
void _recv(int sock, T* data, size_t start, size_t stop) {
if (stop <= start) {
return;
}
data += start;
size_t len = (stop - start) * sizeof(T);
char* buffer = (char*)data;
while (len > 0) {
ssize_t r = recv(sock, buffer, len, 0);
if (r <= 0) {
std::ostringstream msg;
msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
throw std::runtime_error(msg.str());
}
buffer += r;
len -= r;
}
}
template <typename T>
void _recv_sum(int sock, T* data, size_t start, size_t stop) {
if (stop <= start) {
return;
}
data += start;
char buffer[PACKET_SIZE];
size_t len = (stop - start) * sizeof(T);
while (len > 0) {
ssize_t r = 0;
do {
ssize_t partial_r =
recv(sock, buffer + r, std::min(len, PACKET_SIZE) - r, 0);
if (partial_r <= 0) {
std::ostringstream msg;
msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
throw std::runtime_error(msg.str());
}
r += partial_r;
} while (r % sizeof(T));
sum_inplace((const T*)buffer, data, r / sizeof(T));
data += r / sizeof(T);
len -= r;
}
}
template <typename T>
void ring_send(
Barrier& barrier,
int socket,
int rank,
int size,
T* data,
size_t data_size,
int direction = -1) {
// We split the data into `size_` segments of size `segment_size`
size_t segment_size = ceildiv(data_size, size);
// Initial segment
int segment = rank;
// 1st send
for (int i = 0; i < size - 1; i++) {
size_t start = segment * segment_size;
size_t stop = std::min((segment + 1) * segment_size, data_size);
_send<T>(socket, data, start, stop);
barrier.arrive_and_wait();
segment = (segment + size + direction) % size;
}
// 2nd send
for (int i = 0; i < size - 1; i++) {
size_t start = segment * segment_size;
size_t stop = std::min((segment + 1) * segment_size, data_size);
_send<T>(socket, data, start, stop);
barrier.arrive_and_wait();
segment = (segment + size + direction) % size;
}
}
template <typename T>
void ring_recv_sum(
Barrier& barrier,
int socket,
int rank,
int size,
T* data,
size_t data_size,
int direction = -1) {
// We split the data into `size_` segments of size `segment_size`
size_t segment_size = ceildiv(data_size, size);
// Initial segment
int segment = (rank + size + direction) % size;
// Recv sum
for (int i = 0; i < size - 1; i++) {
size_t start = segment * segment_size;
size_t stop = std::min((segment + 1) * segment_size, data_size);
_recv_sum<T>(socket, data, start, stop);
barrier.arrive_and_wait();
segment = (segment + size + direction) % size;
}
// Recv
for (int i = 0; i < size - 1; i++) {
size_t start = segment * segment_size;
size_t stop = std::min((segment + 1) * segment_size, data_size);
_recv<T>(socket, data, start, stop);
barrier.arrive_and_wait();
segment = (segment + size + direction) % size;
}
}
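Both the ring_send / ring_recv_sum helpers removed here and the new all_sum_impl added later in this diff walk the same ring schedule: the payload is split into size_ segments, every rank starts from its own segment, and the segment index rotates by direction for size_ - 1 scatter-reduce steps followed by size_ - 1 all-gather steps. A small standalone sketch (hypothetical four-rank ring, not part of the commit) that only prints that rotation:

  // Sketch only: print which segment each rank sends at every step of the
  // ring all-reduce, mirroring the rotation in ring_send/ring_recv_sum.
  #include <cstdio>

  int main() {
    const int size = 4;       // hypothetical ring size
    const int direction = -1; // send right, receive from the left
    for (int rank = 0; rank < size; rank++) {
      int segment = rank; // each rank starts with its own segment
      std::printf("rank %d sends segments:", rank);
      // size - 1 scatter-reduce steps, then size - 1 all-gather steps
      for (int step = 0; step < 2 * (size - 1); step++) {
        std::printf(" %d", segment);
        segment = (segment + size + direction) % size;
      }
      std::printf("\n");
    }
    return 0;
  }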
} // namespace
class RingGroup : public GroupImpl {
@@ -530,50 +527,59 @@ class RingGroup : public GroupImpl {
// first and accept after.
if (rank_ < connect_to) {
log_info(verbose_, "Rank", rank_, "accepting");
recv_sockets_ = std::move(accept_connections(nodes[rank_]));
sockets_left_ = std::move(accept_connections(nodes[rank_]));
log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
send_sockets_ = std::move(make_connections(nodes[connect_to], verbose));
sockets_right_ = std::move(make_connections(nodes[connect_to], verbose));
} else {
log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
send_sockets_ = std::move(make_connections(nodes[connect_to], verbose));
sockets_right_ = std::move(make_connections(nodes[connect_to], verbose));
log_info(verbose_, "Rank", rank_, "accepting");
recv_sockets_ = std::move(accept_connections(nodes[rank_]));
sockets_left_ = std::move(accept_connections(nodes[rank_]));
}
// Failure if we couldn't make send or recv sockets
// Failure if we couldn't make right or left sockets
if (send_sockets_.empty()) {
if (sockets_right_.empty()) {
std::ostringstream msg;
msg << "[ring] Rank " << rank_ << " has no send sockets.";
msg << "[ring] Rank " << rank_ << " has no sockets to the right.";
throw std::invalid_argument(msg.str());
}
if (recv_sockets_.empty()) {
if (sockets_left_.empty()) {
std::ostringstream msg;
msg << "[ring] Rank " << rank_ << " has no recv sockets.";
msg << "[ring] Rank " << rank_ << " has no sockets to the left.";
throw std::invalid_argument(msg.str());
}
// The following could be relaxed since we can define non-homogeneous rings
// but it makes things a bit simpler for now.
if (send_sockets_.size() != recv_sockets_.size()) {
if (sockets_right_.size() != sockets_left_.size()) {
std::ostringstream msg;
msg << "[ring] It is required to have as many connections to the left as "
<< "to the right but rank " << rank_ << " has "
<< send_sockets_.size() << " connections to the right and "
<< recv_sockets_.size() << " to the left.";
<< sockets_right_.size() << " connections to the right and "
<< sockets_left_.size() << " to the left.";
throw std::invalid_argument(msg.str());
}
// Start the necessary threads for completely parallel operation on all
// channels. One thread to send, one to receive per socket.
pool_.resize(send_sockets_.size() * 2 * 2);
// Start the all reduce threads. One all reduce per direction per ring.
pool_.resize(sockets_right_.size() + sockets_left_.size());
// Create a communication thread per socket. This also converts them to
// non-blocking.
comm_.add(sockets_right_);
comm_.add(sockets_left_);
// Allocate buffers for the all sum
buffers_.resize(
(sockets_right_.size() + sockets_left_.size()) * ALL_SUM_BUFFERS *
ALL_SUM_SIZE);
}
~RingGroup() {
for (auto s : send_sockets_) {
for (auto s : sockets_right_) {
shutdown(s, 2);
close(s);
}
for (auto s : recv_sockets_) {
for (auto s : sockets_left_) {
shutdown(s, 2);
close(s);
}
@@ -594,14 +600,42 @@ class RingGroup : public GroupImpl {
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
throw std::runtime_error("[ring] Group split not supported.");
}
void all_gather(const array& input, array& output) override {
throw std::runtime_error("[ring] All gather not supported.");
}
void send(const array& input, int dst) override {
throw std::runtime_error("[ring] Send not supported.");
}
void recv(array& out, int src) override {
throw std::runtime_error("[ring] Recv not supported.");
}
void send(const array& input_, int dst) override {
// Make sure that the input is row contiguous
array input = ensure_row_contiguous(input_);
int right = (rank_ + 1) % size_;
int left = (rank_ + size_ - 1) % size_;
if (dst == right) {
send(sockets_right_, input.data<char>(), input.nbytes());
} else if (dst == left) {
send(sockets_left_, input.data<char>(), input.nbytes());
} else {
std::ostringstream msg;
msg << "[ring] Send only supported to direct neighbors "
<< "but tried to send to " << dst << " from " << rank_ << std::endl;
throw std::runtime_error(msg.str());
}
}
void recv(array& out, int src) override {
int right = (rank_ + 1) % size_;
int left = (rank_ + size_ - 1) % size_;
if (src == right) {
recv(sockets_right_, out.data<char>(), out.nbytes());
} else if (src == left) {
recv(sockets_left_, out.data<char>(), out.nbytes());
} else {
std::ostringstream msg;
msg << "[ring] Recv only supported from direct neighbors "
<< "but tried to recv from " << src << " to " << rank_ << std::endl;
throw std::runtime_error(msg.str());
}
}
private:
@@ -613,7 +647,8 @@ class RingGroup : public GroupImpl {
// If the input data cannot be split into size_ segments then copy it and
// all reduce a local buffer prefilled with 0s.
if (input.size() < size_) {
// TODO: Maybe allocate dynamically so we don't have the constraint
// below?
if (input.itemsize() * size_ > 1024) {
std::ostringstream msg;
msg << "Can't perform the ring all reduce of " << output.size()
@@ -621,31 +656,16 @@ class RingGroup : public GroupImpl {
throw std::runtime_error(msg.str());
}
std::future<void> sent, recvd;
auto barrier = std::make_unique<Barrier>(2);
char buffer[1024];
std::memset(buffer, 0, size_ * input.itemsize());
std::memcpy(buffer, input.data<char>(), input.nbytes());
sent = pool_.enqueue(
ring_send<T>,
std::reference_wrapper(*barrier),
send_sockets_[0],
rank_,
size_,
(T*)buffer,
size_,
-1);
recvd = pool_.enqueue(
ring_recv_sum<T>,
std::reference_wrapper(*barrier),
recv_sockets_[0],
rank_,
size_,
(T*)buffer,
size_,
-1);
sent.wait();
recvd.wait();
all_sum_impl<T>(
reinterpret_cast<T*>(buffers_.data()),
reinterpret_cast<T*>(buffer),
size_,
sockets_right_[0],
sockets_left_[0],
-1);
std::memcpy(output.data<char>(), buffer, output.nbytes());
return;
}
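For messages with fewer elements than ranks, the path above zero-pads a small stack buffer so the data can still be cut into size_ segments; the padding contributes nothing to the sums and only the original elements are copied back. A standalone sketch of that padding step, with hypothetical sizes not taken from the commit:

  // Sketch only: zero-pad a short input so it still splits into one
  // element per rank, as the small-message branch above does.
  // Hypothetical values: 3 float elements reduced across 8 ranks.
  #include <cstdio>
  #include <cstring>

  int main() {
    const int size = 8;                    // ring size
    float input[3] = {1.f, 2.f, 3.f};      // fewer elements than ranks
    alignas(float) char buffer[1024];      // itemsize * size must stay <= 1024
    std::memset(buffer, 0, size * sizeof(float));
    std::memcpy(buffer, input, sizeof(input));
    // The ring all-sum now sees `size` elements; the zero padding adds
    // nothing to the sums, and only the first 3 results are copied back.
    for (int i = 0; i < size; i++) {
      std::printf("%g ", reinterpret_cast<float*>(buffer)[i]);
    }
    std::printf("\n"); // prints: 1 2 3 0 0 0 0 0
    return 0;
  }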
@@ -655,137 +675,155 @@ class RingGroup : public GroupImpl {
std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
}
// All reduce in place. We have `send_channels_.size()` bidirectional
// channels so let's split the message up and perform as many parallel
// ring-reductions as possible.
std::vector<std::future<void>> reductions;
std::vector<std::unique_ptr<Barrier>> barriers;
size_t packets = ceildiv(output.size(), size_ * PACKET_SIZE);
// Large all reduce territory so let's use all we got
if (packets >= 2 * send_sockets_.size()) {
size_t segment = ceildiv(output.size(), 2 * send_sockets_.size());
for (int i = 0; i < send_sockets_.size(); i++) {
// 1st ring reduce
barriers.emplace_back(std::make_unique<Barrier>(2));
reductions.push_back(pool_.enqueue(
ring_send<T>,
std::reference_wrapper(*barriers.back()),
send_sockets_[i],
rank_,
size_,
output.data<T>() + 2 * i * segment,
std::min(output.size() - 2 * i * segment, segment),
-1));
reductions.push_back(pool_.enqueue(
ring_recv_sum<T>,
std::reference_wrapper(*barriers.back()),
recv_sockets_[i],
rank_,
size_,
output.data<T>() + 2 * i * segment,
std::min(output.size() - 2 * i * segment, segment),
-1));
// 2nd ring reduce
barriers.emplace_back(std::make_unique<Barrier>(2));
reductions.push_back(pool_.enqueue(
ring_send<T>,
std::reference_wrapper(*barriers.back()),
recv_sockets_[i],
rank_,
size_,
output.data<T>() + (2 * i + 1) * segment,
std::min(output.size() - (2 * i + 1) * segment, segment),
1));
reductions.push_back(pool_.enqueue(
ring_recv_sum<T>,
std::reference_wrapper(*barriers.back()),
send_sockets_[i],
rank_,
size_,
output.data<T>() + (2 * i + 1) * segment,
std::min(output.size() - (2 * i + 1) * segment, segment),
1));
}
}
// At least 2 reductions so we can be from small to medium
else if (packets > 1) {
size_t segment = ceildiv(output.size(), packets);
for (int i = 0; i < send_sockets_.size(); i++) {
barriers.emplace_back(std::make_unique<Barrier>(2));
reductions.push_back(pool_.enqueue(
ring_send<T>,
std::reference_wrapper(*barriers.back()),
send_sockets_[i],
rank_,
size_,
output.data<T>() + i * segment,
std::min(output.size() - i * segment, segment),
-1));
reductions.push_back(pool_.enqueue(
ring_recv_sum<T>,
std::reference_wrapper(*barriers.back()),
recv_sockets_[i],
rank_,
size_,
output.data<T>() + i * segment,
std::min(output.size() - i * segment, segment),
-1));
}
for (int i = 0; i < packets - send_sockets_.size(); i++) {
barriers.emplace_back(std::make_unique<Barrier>(2));
reductions.push_back(pool_.enqueue(
ring_send<T>,
std::reference_wrapper(*barriers.back()),
recv_sockets_[i],
rank_,
size_,
output.data<T>() + (send_sockets_.size() + i) * segment,
std::min(
output.size() - (send_sockets_.size() + i) * segment, segment),
1));
reductions.push_back(pool_.enqueue(
ring_recv_sum<T>,
std::reference_wrapper(*barriers.back()),
send_sockets_[i],
rank_,
size_,
output.data<T>() + (send_sockets_.size() + i) * segment,
std::min(
output.size() - (send_sockets_.size() + i) * segment, segment),
1));
}
}
// Small reduction which won't really benefit much from parallelization.
// TODO: Verify that this is true cause PACKET_SIZE * size_ can still be a
// fairly large array.
else {
barriers.emplace_back(std::make_unique<Barrier>(2));
reductions.push_back(pool_.enqueue(
ring_send<T>,
std::reference_wrapper(*barriers.back()),
send_sockets_[0],
rank_,
size_,
output.data<T>(),
output.size(),
-1));
reductions.push_back(pool_.enqueue(
ring_recv_sum<T>,
std::reference_wrapper(*barriers.back()),
recv_sockets_[0],
rank_,
size_,
output.data<T>(),
output.size(),
-1));
}
// Wait for the reductions to finish.
for (auto& f : reductions) {
f.wait();
}
}
// Split the all reduces so that each member has at least 1 buffer to
// send/recv per segment.
constexpr size_t min_send_size = 262144;
size_t n_reduces = std::max(
std::min(
sockets_right_.size() + sockets_left_.size(),
output.nbytes() / (size_ * min_send_size)),
1UL);
size_t step = ceildiv(output.size(), n_reduces);
std::vector<std::future<void>> all_sums;
for (int i = 0; i < n_reduces; i++) {
all_sums.emplace_back(pool_.enqueue(std::bind(
&RingGroup::all_sum_impl<T>,
this,
reinterpret_cast<T*>(
buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS),
output.data<T>() + i * step,
std::min(output.size(), (i + 1) * step) - i * step,
sockets_right_[i / 2],
sockets_left_[i / 2],
(i % 2) ? -1 : 1)));
}
for (auto& f : all_sums) {
f.wait();
}
}
template <typename T>
void all_sum_impl(
T* buffer,
T* data,
size_t data_size,
int socket_right,
int socket_left,
int direction) {
// Choose which socket we send to and recv from
int socket_send = (direction < 0) ? socket_right : socket_left;
int socket_recv = (direction < 0) ? socket_left : socket_right;
// We split the data into `size_` segments of size `segment_size` and each
// of these in smaller segments of ALL_SUM_SIZE which we 'll call packets.
size_t segment_size = ceildiv(data_size, size_);
size_t BUFFER_SIZE =
std::max(32768UL, std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2));
size_t n_packets = ceildiv(segment_size, BUFFER_SIZE);
// Initial segments
int send_segment = rank_;
int recv_segment = (rank_ + direction + size_) % size_;
// Plan the whole reduce in terms of sends and recvs as indices in data.
// It makes the actual async send and recv a bit simpler to follow when
// there are less offset calculations around.
std::vector<std::pair<size_t, size_t>> send_plan;
std::vector<std::pair<size_t, size_t>> recv_plan;
// Two times the same send/recv operations, first scatter reduce and then
// gather.
for (int k = 0; k < 2; k++) {
for (int i = 0; i < size_ - 1; i++) {
size_t send_start = send_segment * segment_size;
size_t send_stop =
std::min((send_segment + 1) * segment_size, data_size);
size_t recv_start = recv_segment * segment_size;
size_t recv_stop =
std::min((recv_segment + 1) * segment_size, data_size);
for (size_t j = 0; j < n_packets; j++) {
send_plan.emplace_back(
std::min(send_start + j * BUFFER_SIZE, send_stop),
std::min(send_start + (j + 1) * BUFFER_SIZE, send_stop));
recv_plan.emplace_back(
std::min(recv_start + j * BUFFER_SIZE, recv_stop),
std::min(recv_start + (j + 1) * BUFFER_SIZE, recv_stop));
}
send_segment = (send_segment + size_ + direction) % size_;
recv_segment = (recv_segment + size_ + direction) % size_;
}
}
// Running the plan is fairly simple, we keep a send and a recv in flight
// while doing the summation.
T* recv_buffers[ALL_SUM_BUFFERS];
for (int i = 0; i < ALL_SUM_BUFFERS; i++) {
recv_buffers[i] = buffer + i * BUFFER_SIZE;
}
std::future<void> sends[2], recvs[2];
int a = 0;
int b = (n_packets > 1) ? 1 : 0;
for (int i = 0, j = -b; i < send_plan.size(); j++, i++) {
sends[a] = comm_.send(
socket_send,
data + send_plan[i].first,
send_plan[i].second - send_plan[i].first);
if (2 * i < send_plan.size()) {
recvs[a] = comm_.recv(
socket_recv,
recv_buffers[i % ALL_SUM_BUFFERS],
recv_plan[i].second - recv_plan[i].first);
} else {
recvs[a] = comm_.recv(
socket_recv,
data + recv_plan[i].first,
recv_plan[i].second - recv_plan[i].first);
}
if (j >= 0) {
sends[b].wait();
recvs[b].wait();
if (2 * j < send_plan.size()) {
sum_inplace<T>(
recv_buffers[j % ALL_SUM_BUFFERS],
data + recv_plan[j].first,
recv_plan[j].second - recv_plan[j].first);
}
}
std::swap(a, b);
}
sends[b].wait();
recvs[b].wait();
}
void send(const std::vector<int>& sockets, char* data, size_t data_size) {
size_t segment_size = ceildiv(data_size, sockets.size());
std::vector<std::future<void>> sends;
for (int i = 0; i < sockets.size(); i++) {
sends.emplace_back(comm_.send(
sockets[i],
data + i * segment_size,
std::min(data_size, (i + 1) * segment_size) - i * segment_size));
}
for (auto& f : sends) {
f.wait();
}
}
void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
size_t segment_size = ceildiv(data_size, sockets.size());
std::vector<std::future<void>> recvs;
for (int i = 0; i < sockets.size(); i++) {
recvs.emplace_back(comm_.recv(
sockets[i],
data + i * segment_size,
std::min(data_size, (i + 1) * segment_size) - i * segment_size));
}
for (auto& f : recvs) {
f.wait();
}
}
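To make the sizing logic concrete, the following standalone sketch reruns the n_reduces / step arithmetic from all_sum and the BUFFER_SIZE / n_packets arithmetic from all_sum_impl for one hypothetical configuration: 64 MiB of float32 over 4 ranks with one socket per side. These inputs are assumptions for illustration, not values taken from the commit.

  // Sketch only: reproduce the all_sum / all_sum_impl sizing arithmetic
  // for hypothetical inputs.
  #include <algorithm>
  #include <cstdio>

  int main() {
    const size_t ALL_SUM_SIZE = 8 * 1024 * 1024; // bytes, as in the commit
    const size_t min_send_size = 262144;         // bytes, as in the commit
    const size_t size_ = 4;                      // ranks in the ring (assumed)
    const size_t n_sockets = 2;                  // one left + one right socket (assumed)
    const size_t nbytes = 64UL * 1024 * 1024;    // hypothetical message size
    const size_t nelems = nbytes / sizeof(float);

    // How many parallel reductions, and how many elements each handles.
    size_t n_reduces =
        std::max(std::min(n_sockets, nbytes / (size_ * min_send_size)), 1UL);
    size_t step = (nelems + n_reduces - 1) / n_reduces; // ceildiv

    // Packet pipeline inside one all_sum_impl call.
    size_t data_size = step;
    size_t segment_size = (data_size + size_ - 1) / size_;
    size_t BUFFER_SIZE = std::max(
        32768UL, std::min(ALL_SUM_SIZE / sizeof(float), segment_size / 2));
    size_t n_packets = (segment_size + BUFFER_SIZE - 1) / BUFFER_SIZE;

    std::printf("n_reduces=%zu step=%zu segment=%zu buffer=%zu packets=%zu\n",
                n_reduces, step, segment_size, BUFFER_SIZE, n_packets);
    // Expected: n_reduces=2 step=8388608 segment=2097152 buffer=1048576 packets=2
    return 0;
  }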
@@ -796,9 +834,12 @@ class RingGroup : public GroupImpl {
bool verbose_;
ThreadPool pool_;
CommunicationThreads comm_;
std::vector<int> send_sockets_;
std::vector<int> recv_sockets_;
std::vector<int> sockets_right_;
std::vector<int> sockets_left_;
std::vector<char> buffers_;
};
bool is_available() {


@@ -56,6 +56,45 @@ class TestRingDistributed(mlx_tests.MLXTestCase):
maxrelerror = ((y - z).abs() / z.abs()).max()
self.assertLessEqual(maxrelerror, rtol)
    def test_send_recv(self):
        world = mx.distributed.init()
        dtypes = [
            mx.int8,
            mx.uint8,
            mx.int16,
            mx.uint16,
            mx.int32,
            mx.uint32,
            mx.float32,
            mx.float16,
            mx.bfloat16,
            mx.complex64,
        ]
        sizes = [
            (7,),
            (10,),
            (1024,),
            (1024, 1024),
        ]
        key = mx.random.key(0)
        right = (world.rank() + 1) % world.size()
        left = (world.rank() + world.size() - 1) % world.size()

        for dt in dtypes:
            for sh in sizes:
                x = (
                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
                ).astype(dt)

                if world.rank() % 2 == 0:
                    y = mx.distributed.send(x[world.rank()], right)
                    z = mx.distributed.recv_like(y, left)
                    mx.eval(y, z)
                else:
                    z = mx.distributed.recv_like(x[world.rank()], left)
                    y = mx.distributed.send(x[world.rank()], right)
                    mx.eval(z, y)

                self.assertTrue(mx.all(y == x[world.rank()]))
                self.assertTrue(mx.all(z == x[left]))
if __name__ == "__main__":
    unittest.main()