mirror of https://github.com/ml-explore/mlx.git
Ring update (#1885)

This commit is contained in:
parent 0ebc8a3d25
commit 10b271d963
@@ -1,15 +1,19 @@
 // Copyright © 2024 Apple Inc.

 #include <arpa/inet.h>
+#include <fcntl.h>
 #include <netdb.h>
 #include <sys/socket.h>
 #include <unistd.h>

 #include <chrono>
 #include <fstream>
+#include <future>
 #include <iostream>
+#include <list>
 #include <sstream>
 #include <thread>
+#include <unordered_map>

 #include <json.hpp>

@@ -80,53 +84,17 @@

 namespace mlx::core::distributed::ring {

-constexpr const size_t PACKET_SIZE = 262144;
+constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
+constexpr const size_t ALL_SUM_BUFFERS = 2;
 constexpr const int CONN_ATTEMPTS = 5;
 constexpr const int CONN_WAIT = 1000;

 using GroupImpl = mlx::core::distributed::detail::GroupImpl;
 using json = nlohmann::json;
+using namespace std::chrono_literals;

 namespace {

-class Barrier {
- public:
-  explicit Barrier(int n_threads)
-      : n_threads_(n_threads), count_(0), flag_(false) {}
-
-  void arrive_and_wait() {
-    std::unique_lock<std::mutex> lock(mtx_);
-
-    // Keep the flag that marks the current use of the barrier. The next use is
-    // going to have this flag flipped.
-    bool initial_flag = flag_;
-
-    // Increment the count
-    count_++;
-
-    // We are the last thread to arrive so reset the count, change the flag and
-    // notify everybody.
-    if (count_ == n_threads_) {
-      count_ = 0;
-      flag_ = !flag_;
-      cv_.notify_all();
-    }
-
-    // Wait for the rest to arrive
-    else {
-      cv_.wait(lock, [this, initial_flag]() { return initial_flag != flag_; });
-    }
-  }
-
- private:
-  std::mutex mtx_;
-  std::condition_variable cv_;
-  int n_threads_;
-
-  int count_;
-  bool flag_; // we need this for sequential use of the barrier
-};

 template <typename T>
 void log(std::ostream& os, T first) {
   os << first << std::endl;
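The Barrier removed above was reusable across rounds: each use flips flag_, so a fast thread re-entering arrive_and_wait() cannot race past waiters from the previous round. A minimal sketch of how the old code paired it with one sender and one receiver thread; it assumes the deleted Barrier class is still in scope, and the loop body is illustrative:

#include <thread>

void barrier_demo() {
  Barrier barrier(2); // the class deleted in this hunk
  auto work = [&]() {
    for (int round = 0; round < 3; round++) {
      // ... move one segment over the socket ...
      barrier.arrive_and_wait(); // sender and receiver advance in lock step
    }
  };
  // One sender and one receiver per ring, as in the old ring_send /
  // ring_recv_sum pair removed further down.
  std::thread sender(work), receiver(work);
  sender.join();
  receiver.join();
}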
@@ -151,6 +119,169 @@ decltype(T() * U()) ceildiv(T a, U b) {
   return (a + b - 1) / b;
 }

+class SocketThread {
+ public:
+  SocketThread(int fd) : fd_(fd), stop_(false) {
+    worker_ = std::thread(&SocketThread::worker, this);
+    int flags = fcntl(fd, F_GETFL, 0);
+    fcntl(fd, F_SETFL, flags | O_NONBLOCK);
+  }
+
+  ~SocketThread() {
+    stop_ = true;
+    condition_.notify_all();
+    worker_.join();
+    int flags = fcntl(fd_, F_GETFL, 0);
+    fcntl(fd_, F_SETFL, flags & ~O_NONBLOCK);
+  }
+
+  template <typename T>
+  std::future<void> send(T* buffer, size_t size) {
+    return send_impl(reinterpret_cast<char*>(buffer), size * sizeof(T));
+  }
+
+  template <typename T>
+  std::future<void> recv(T* buffer, size_t size) {
+    return recv_impl(reinterpret_cast<char*>(buffer), size * sizeof(T));
+  }
+
+ private:
+  struct SocketTask {
+    SocketTask(void* b, size_t s, std::promise<void>&& p)
+        : buffer(b), size(s), promise(std::move(p)) {}
+    SocketTask(SocketTask&& t)
+        : buffer(t.buffer), size(t.size), promise(std::move(t.promise)) {}
+    void* buffer;
+    size_t size;
+    std::promise<void> promise;
+  };
+
+  std::future<void> send_impl(char* buffer, size_t size) {
+    std::promise<void> send_completed_promise;
+    auto send_completed_future = send_completed_promise.get_future();
+    if (size == 0) {
+      send_completed_promise.set_value();
+      return send_completed_future;
+    }
+
+    {
+      std::unique_lock lock(queue_mutex_);
+      sends_.emplace_back(
+          SocketTask(buffer, size, std::move(send_completed_promise)));
+    }
+    condition_.notify_one();
+    return send_completed_future;
+  }
+
+  std::future<void> recv_impl(char* buffer, size_t size) {
+    std::promise<void> recv_completed_promise;
+    auto recv_completed_future = recv_completed_promise.get_future();
+    if (size == 0) {
+      recv_completed_promise.set_value();
+      return recv_completed_future;
+    }
+
+    {
+      std::unique_lock lock(queue_mutex_);
+      recvs_.emplace_back(
+          SocketTask(buffer, size, std::move(recv_completed_promise)));
+    }
+    condition_.notify_one();
+    return recv_completed_future;
+  }
+
+  bool have_tasks() {
+    return !(sends_.empty() && recvs_.empty());
+  }
+
+  void worker() {
+    bool delete_recv = false;
+    bool delete_send = false;
+    while (true) {
+      {
+        std::unique_lock lock(queue_mutex_);
+
+        if (delete_recv) {
+          recvs_.front().promise.set_value();
+          recvs_.pop_front();
+          delete_recv = false;
+        }
+        if (delete_send) {
+          sends_.front().promise.set_value();
+          sends_.pop_front();
+          delete_send = false;
+        }
+
+        if (stop_) {
+          return;
+        }
+
+        if (!have_tasks()) {
+          condition_.wait(lock, [this] { return stop_ || have_tasks(); });
+          if (stop_) {
+            return;
+          }
+        }
+      }
+
+      if (!recvs_.empty()) {
+        auto& task = recvs_.front();
+        ssize_t r = ::recv(fd_, task.buffer, task.size, 0);
+        if (r >= 0) {
+          task.buffer = static_cast<char*>(task.buffer) + r;
+          task.size -= r;
+          delete_recv = task.size == 0;
+        } else if (errno != EAGAIN) {
+          log_info(
+              true, "Receiving from socket", fd_, "failed with errno", errno);
+          return;
+        }
+      }
+      if (!sends_.empty()) {
+        auto& task = sends_.front();
+        ssize_t r = ::send(fd_, task.buffer, task.size, 0);
+        if (r >= 0) {
+          task.buffer = static_cast<char*>(task.buffer) + r;
+          task.size -= r;
+          delete_send = task.size == 0;
+        } else if (errno != EAGAIN) {
+          log_info(true, "Sending to socket", fd_, "failed with errno", errno);
+          return;
+        }
+      }
+    }
+  }
+
+  int fd_;
+  bool stop_;
+  std::thread worker_;
+  std::mutex queue_mutex_;
+  std::condition_variable condition_;
+  std::list<SocketTask> sends_;
+  std::list<SocketTask> recvs_;
+};
+
+class CommunicationThreads {
+ public:
+  void add(const std::vector<int>& sockets) {
+    for (int sock : sockets) {
+      threads_.emplace(sock, sock);
+    }
+  }
+
+  template <typename T>
+  std::future<void> send(int socket, T* buffer, size_t size) {
+    return threads_.at(socket).send<T>(buffer, size);
+  }
+
+  template <typename T>
+  std::future<void> recv(int socket, T* buffer, size_t size) {
+    return threads_.at(socket).recv<T>(buffer, size);
+  }
+
+ private:
+  std::unordered_map<int, SocketThread> threads_;
+};
+
 struct address_t {
   sockaddr_storage addr;
   socklen_t len;
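A standalone sketch of how the new classes compose, using a socketpair so both ends live in one process. Only SocketThread and CommunicationThreads come from the hunk above (assumed in scope); the fds and sizes are illustrative:

#include <sys/socket.h>

#include <cstdio>
#include <vector>

int main() {
  // Two connected fds standing in for the ring's left/right sockets.
  int fds[2];
  socketpair(AF_UNIX, SOCK_STREAM, 0, fds);

  CommunicationThreads comm;
  comm.add({fds[0], fds[1]}); // one worker per fd; fds become non-blocking

  float out[1024], in[1024];
  for (int i = 0; i < 1024; i++) {
    out[i] = float(i);
  }

  // Both calls enqueue and return immediately; the futures resolve once the
  // worker threads have moved the full 4 KiB.
  auto sent = comm.send(fds[0], out, 1024);
  auto recvd = comm.recv(fds[1], in, 1024);
  sent.wait();
  recvd.wait();
  std::printf("in[1023] = %.1f\n", in[1023]); // 1023.0
}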
@@ -378,140 +509,6 @@ void sum_inplace(const T* input, T* output, size_t N) {
   }
 }

-template <typename T>
-void _send(int sock, T* data, size_t start, size_t stop) {
-  if (stop <= start) {
-    return;
-  }
-  data += start;
-  size_t len = (stop - start) * sizeof(T);
-  const char* buffer = (const char*)data;
-  while (len > 0) {
-    ssize_t r = send(sock, buffer, len, 0);
-    if (r <= 0) {
-      std::ostringstream msg;
-      msg << "Send of " << len << " bytes failed (errno: " << errno << ")";
-      throw std::runtime_error(msg.str());
-    }
-    buffer += r;
-    len -= r;
-  }
-}
-
-template <typename T>
-void _recv(int sock, T* data, size_t start, size_t stop) {
-  if (stop <= start) {
-    return;
-  }
-  data += start;
-  size_t len = (stop - start) * sizeof(T);
-  char* buffer = (char*)data;
-  while (len > 0) {
-    ssize_t r = recv(sock, buffer, len, 0);
-    if (r <= 0) {
-      std::ostringstream msg;
-      msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
-      throw std::runtime_error(msg.str());
-    }
-    buffer += r;
-    len -= r;
-  }
-}
-
-template <typename T>
-void _recv_sum(int sock, T* data, size_t start, size_t stop) {
-  if (stop <= start) {
-    return;
-  }
-  data += start;
-  char buffer[PACKET_SIZE];
-  size_t len = (stop - start) * sizeof(T);
-  while (len > 0) {
-    ssize_t r = 0;
-    do {
-      ssize_t partial_r =
-          recv(sock, buffer + r, std::min(len, PACKET_SIZE) - r, 0);
-      if (partial_r <= 0) {
-        std::ostringstream msg;
-        msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
-        throw std::runtime_error(msg.str());
-      }
-      r += partial_r;
-    } while (r % sizeof(T));
-    sum_inplace((const T*)buffer, data, r / sizeof(T));
-    data += r / sizeof(T);
-    len -= r;
-  }
-}
-
-template <typename T>
-void ring_send(
-    Barrier& barrier,
-    int socket,
-    int rank,
-    int size,
-    T* data,
-    size_t data_size,
-    int direction = -1) {
-  // We split the data into `size_` segments of size `segment_size`
-  size_t segment_size = ceildiv(data_size, size);
-
-  // Initial segment
-  int segment = rank;
-
-  // 1st send
-  for (int i = 0; i < size - 1; i++) {
-    size_t start = segment * segment_size;
-    size_t stop = std::min((segment + 1) * segment_size, data_size);
-    _send<T>(socket, data, start, stop);
-    barrier.arrive_and_wait();
-    segment = (segment + size + direction) % size;
-  }
-
-  // 2nd send
-  for (int i = 0; i < size - 1; i++) {
-    size_t start = segment * segment_size;
-    size_t stop = std::min((segment + 1) * segment_size, data_size);
-    _send<T>(socket, data, start, stop);
-    barrier.arrive_and_wait();
-    segment = (segment + size + direction) % size;
-  }
-}
-
-template <typename T>
-void ring_recv_sum(
-    Barrier& barrier,
-    int socket,
-    int rank,
-    int size,
-    T* data,
-    size_t data_size,
-    int direction = -1) {
-  // We split the data into `size_` segments of size `segment_size`
-  size_t segment_size = ceildiv(data_size, size);
-
-  // Initial segment
-  int segment = (rank + size + direction) % size;
-
-  // Recv sum
-  for (int i = 0; i < size - 1; i++) {
-    size_t start = segment * segment_size;
-    size_t stop = std::min((segment + 1) * segment_size, data_size);
-    _recv_sum<T>(socket, data, start, stop);
-    barrier.arrive_and_wait();
-    segment = (segment + size + direction) % size;
-  }
-
-  // Recv
-  for (int i = 0; i < size - 1; i++) {
-    size_t start = segment * segment_size;
-    size_t stop = std::min((segment + 1) * segment_size, data_size);
-    _recv<T>(socket, data, start, stop);
-    barrier.arrive_and_wait();
-    segment = (segment + size + direction) % size;
-  }
-}
-
 } // namespace

 class RingGroup : public GroupImpl {
@@ -530,50 +527,59 @@ class RingGroup : public GroupImpl {
     // first and accept after.
     if (rank_ < connect_to) {
       log_info(verbose_, "Rank", rank_, "accepting");
-      recv_sockets_ = std::move(accept_connections(nodes[rank_]));
+      sockets_left_ = std::move(accept_connections(nodes[rank_]));
       log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
-      send_sockets_ = std::move(make_connections(nodes[connect_to], verbose));
+      sockets_right_ = std::move(make_connections(nodes[connect_to], verbose));
     } else {
       log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
-      send_sockets_ = std::move(make_connections(nodes[connect_to], verbose));
+      sockets_right_ = std::move(make_connections(nodes[connect_to], verbose));
       log_info(verbose_, "Rank", rank_, "accepting");
-      recv_sockets_ = std::move(accept_connections(nodes[rank_]));
+      sockets_left_ = std::move(accept_connections(nodes[rank_]));
     }

-    // Failure if we couldn't make send or recv sockets
-    if (send_sockets_.empty()) {
+    // Failure if we couldn't make right or left sockets
+    if (sockets_right_.empty()) {
       std::ostringstream msg;
-      msg << "[ring] Rank " << rank_ << " has no send sockets.";
+      msg << "[ring] Rank " << rank_ << " has no sockets to the right.";
       throw std::invalid_argument(msg.str());
     }
-    if (recv_sockets_.empty()) {
+    if (sockets_left_.empty()) {
       std::ostringstream msg;
-      msg << "[ring] Rank " << rank_ << " has no recv sockets.";
+      msg << "[ring] Rank " << rank_ << " has no sockets to the left.";
       throw std::invalid_argument(msg.str());
     }

     // The following could be relaxed since we can define non-homogeneous rings
     // but it makes things a bit simpler for now.
-    if (send_sockets_.size() != recv_sockets_.size()) {
+    if (sockets_right_.size() != sockets_left_.size()) {
       std::ostringstream msg;
       msg << "[ring] It is required to have as many connections to the left as "
           << "to the right but rank " << rank_ << " has "
-          << send_sockets_.size() << " connections to the right and "
-          << recv_sockets_.size() << " to the left.";
+          << sockets_right_.size() << " connections to the right and "
+          << sockets_left_.size() << " to the left.";
       throw std::invalid_argument(msg.str());
     }

-    // Start the necessary threads for completely parallel operation on all
-    // channels. One thread to send, one to receive per socket.
-    pool_.resize(send_sockets_.size() * 2 * 2);
+    // Start the all reduce threads. One all reduce per direction per ring.
+    pool_.resize(sockets_right_.size() + sockets_left_.size());
+
+    // Create a communication thread per socket. This also converts them to
+    // non-blocking.
+    comm_.add(sockets_right_);
+    comm_.add(sockets_left_);
+
+    // Allocate buffers for the all sum
+    buffers_.resize(
+        (sockets_right_.size() + sockets_left_.size()) * ALL_SUM_BUFFERS *
+        ALL_SUM_SIZE);
   }

   ~RingGroup() {
-    for (auto s : send_sockets_) {
+    for (auto s : sockets_right_) {
       shutdown(s, 2);
       close(s);
     }
-    for (auto s : recv_sockets_) {
+    for (auto s : sockets_left_) {
       shutdown(s, 2);
       close(s);
     }
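Worked arithmetic for the scratch allocation above, with illustrative numbers (2 sockets per direction) and the constants from this diff:

size_t n_sockets = 2 + 2;                           // sockets_right_ + sockets_left_
size_t scratch = n_sockets * 2 * (8 * 1024 * 1024); // ALL_SUM_BUFFERS * ALL_SUM_SIZE
                                                    // = 4 * 2 * 8 MiB = 64 MiB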
@@ -594,14 +600,42 @@ class RingGroup : public GroupImpl {
   std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
     throw std::runtime_error("[ring] Group split not supported.");
   }

   void all_gather(const array& input, array& output) override {
     throw std::runtime_error("[ring] All gather not supported.");
   }
-  void send(const array& input, int dst) override {
-    throw std::runtime_error("[ring] Send not supported.");
+
+  void send(const array& input_, int dst) override {
+    // Make sure that the input is row contiguous
+    array input = ensure_row_contiguous(input_);
+
+    int right = (rank_ + 1) % size_;
+    int left = (rank_ + size_ - 1) % size_;
+    if (dst == right) {
+      send(sockets_right_, input.data<char>(), input.nbytes());
+    } else if (dst == left) {
+      send(sockets_left_, input.data<char>(), input.nbytes());
+    } else {
+      std::ostringstream msg;
+      msg << "[ring] Send only supported to direct neighbors "
+          << "but tried to send to " << dst << " from " << rank_ << std::endl;
+      throw std::runtime_error(msg.str());
+    }
   }

   void recv(array& out, int src) override {
-    throw std::runtime_error("[ring] Recv not supported.");
+    int right = (rank_ + 1) % size_;
+    int left = (rank_ + size_ - 1) % size_;
+    if (src == right) {
+      recv(sockets_right_, out.data<char>(), out.nbytes());
+    } else if (src == left) {
+      recv(sockets_left_, out.data<char>(), out.nbytes());
+    } else {
+      std::ostringstream msg;
+      msg << "[ring] Recv only supported from direct neighbors "
+          << "but tried to recv from " << src << " to " << rank_ << std::endl;
+      throw std::runtime_error(msg.str());
+    }
   }

  private:
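For reference, the neighbor arithmetic both overrides rely on, with an illustrative 4-node ring at rank 0:

int size_ = 4, rank_ = 0;               // hypothetical group
int right = (rank_ + 1) % size_;        // 1
int left = (rank_ + size_ - 1) % size_; // 3
// dst == 2 is not a direct neighbor and takes the throwing branch above.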
@@ -613,7 +647,8 @@ class RingGroup : public GroupImpl {
     // If the input data cannot be split into size_ segments then copy it and
     // all reduce a local buffer prefilled with 0s.
     if (input.size() < size_) {
-      // TODO: Maybe allocate dynamically so we don't have the constraint below?
+      // TODO: Maybe allocate dynamically so we don't have the constraint
+      // below?
       if (input.itemsize() * size_ > 1024) {
         std::ostringstream msg;
         msg << "Can't perform the ring all reduce of " << output.size()
@@ -621,31 +656,16 @@ class RingGroup : public GroupImpl {
         throw std::runtime_error(msg.str());
       }

-      std::future<void> sent, recvd;
-      auto barrier = std::make_unique<Barrier>(2);
       char buffer[1024];
       std::memset(buffer, 0, size_ * input.itemsize());
       std::memcpy(buffer, input.data<char>(), input.nbytes());
-      sent = pool_.enqueue(
-          ring_send<T>,
-          std::reference_wrapper(*barrier),
-          send_sockets_[0],
-          rank_,
-          size_,
-          (T*)buffer,
-          size_,
-          -1);
-      recvd = pool_.enqueue(
-          ring_recv_sum<T>,
-          std::reference_wrapper(*barrier),
-          recv_sockets_[0],
-          rank_,
-          size_,
-          (T*)buffer,
-          size_,
-          -1);
-      sent.wait();
-      recvd.wait();
+      all_sum_impl<T>(
+          reinterpret_cast<T*>(buffers_.data()),
+          reinterpret_cast<T*>(buffer),
+          size_,
+          sockets_right_[0],
+          sockets_left_[0],
+          -1);
       std::memcpy(output.data<char>(), buffer, output.nbytes());
       return;
     }
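Worked example for the small path above, assuming 4 ranks and a 3-element float32 input. The input is too small to split into size_ segments, so it is zero-padded to exactly size_ elements (one per rank) before the ring all-sum:

size_t itemsize = 4;          // float32
size_t padded = 4 * itemsize; // size_ elements -> 16 bytes, passes the
                              // 1024-byte stack-buffer guard above
// buffer[0..2] = input, buffer[3] = 0; after the all-sum each rank holds
// all four elementwise sums and copies the first input.nbytes() back out.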
@@ -655,137 +675,155 @@ class RingGroup : public GroupImpl {
       std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
     }

-    // All reduce in place. We have `send_channels_.size()` bidirectional
-    // channels so let's split the message up and perform as many parallel
-    // ring-reductions as possible.
-    std::vector<std::future<void>> reductions;
-    std::vector<std::unique_ptr<Barrier>> barriers;
-    size_t packets = ceildiv(output.size(), size_ * PACKET_SIZE);
-
-    // Large all reduce territory so let's use all we got
-    if (packets >= 2 * send_sockets_.size()) {
-      size_t segment = ceildiv(output.size(), 2 * send_sockets_.size());
-      for (int i = 0; i < send_sockets_.size(); i++) {
-        // 1st ring reduce
-        barriers.emplace_back(std::make_unique<Barrier>(2));
-        reductions.push_back(pool_.enqueue(
-            ring_send<T>,
-            std::reference_wrapper(*barriers.back()),
-            send_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + 2 * i * segment,
-            std::min(output.size() - 2 * i * segment, segment),
-            -1));
-        reductions.push_back(pool_.enqueue(
-            ring_recv_sum<T>,
-            std::reference_wrapper(*barriers.back()),
-            recv_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + 2 * i * segment,
-            std::min(output.size() - 2 * i * segment, segment),
-            -1));
-
-        // 2nd ring reduce
-        barriers.emplace_back(std::make_unique<Barrier>(2));
-        reductions.push_back(pool_.enqueue(
-            ring_send<T>,
-            std::reference_wrapper(*barriers.back()),
-            recv_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + (2 * i + 1) * segment,
-            std::min(output.size() - (2 * i + 1) * segment, segment),
-            1));
-        reductions.push_back(pool_.enqueue(
-            ring_recv_sum<T>,
-            std::reference_wrapper(*barriers.back()),
-            send_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + (2 * i + 1) * segment,
-            std::min(output.size() - (2 * i + 1) * segment, segment),
-            1));
-      }
-    }
-
-    // At least 2 reductions so we can be from small to medium
-    else if (packets > 1) {
-      size_t segment = ceildiv(output.size(), packets);
-      for (int i = 0; i < send_sockets_.size(); i++) {
-        barriers.emplace_back(std::make_unique<Barrier>(2));
-        reductions.push_back(pool_.enqueue(
-            ring_send<T>,
-            std::reference_wrapper(*barriers.back()),
-            send_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + i * segment,
-            std::min(output.size() - i * segment, segment),
-            -1));
-        reductions.push_back(pool_.enqueue(
-            ring_recv_sum<T>,
-            std::reference_wrapper(*barriers.back()),
-            recv_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + i * segment,
-            std::min(output.size() - i * segment, segment),
-            -1));
-      }
-      for (int i = 0; i < packets - send_sockets_.size(); i++) {
-        barriers.emplace_back(std::make_unique<Barrier>(2));
-        reductions.push_back(pool_.enqueue(
-            ring_send<T>,
-            std::reference_wrapper(*barriers.back()),
-            recv_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + (send_sockets_.size() + i) * segment,
-            std::min(
-                output.size() - (send_sockets_.size() + i) * segment, segment),
-            1));
-        reductions.push_back(pool_.enqueue(
-            ring_recv_sum<T>,
-            std::reference_wrapper(*barriers.back()),
-            send_sockets_[i],
-            rank_,
-            size_,
-            output.data<T>() + (send_sockets_.size() + i) * segment,
-            std::min(
-                output.size() - (send_sockets_.size() + i) * segment, segment),
-            1));
-      }
-    }
-
-    // Small reduction which won't really benefit much from parallelization.
-    // TODO: Verify that this is true cause PACKET_SIZE * size_ can still be a
-    // fairly large array.
-    else {
-      barriers.emplace_back(std::make_unique<Barrier>(2));
-      reductions.push_back(pool_.enqueue(
-          ring_send<T>,
-          std::reference_wrapper(*barriers.back()),
-          send_sockets_[0],
-          rank_,
-          size_,
-          output.data<T>(),
-          output.size(),
-          -1));
-      reductions.push_back(pool_.enqueue(
-          ring_recv_sum<T>,
-          std::reference_wrapper(*barriers.back()),
-          recv_sockets_[0],
-          rank_,
-          size_,
-          output.data<T>(),
-          output.size(),
-          -1));
-    }
-
-    // Wait for the reductions to finish.
-    for (auto& f : reductions) {
+    // Split the all reduces so that each member has at least 1 buffer to
+    // send/recv per segment.
+    constexpr size_t min_send_size = 262144;
+    size_t n_reduces = std::max(
+        std::min(
+            sockets_right_.size() + sockets_left_.size(),
+            output.nbytes() / (size_ * min_send_size)),
+        1UL);
+    size_t step = ceildiv(output.size(), n_reduces);
+    std::vector<std::future<void>> all_sums;
+
+    for (int i = 0; i < n_reduces; i++) {
+      all_sums.emplace_back(pool_.enqueue(std::bind(
+          &RingGroup::all_sum_impl<T>,
+          this,
+          reinterpret_cast<T*>(
+              buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS),
+          output.data<T>() + i * step,
+          std::min(output.size(), (i + 1) * step) - i * step,
+          sockets_right_[i / 2],
+          sockets_left_[i / 2],
+          (i % 2) ? -1 : 1)));
+    }
+    for (auto& f : all_sums) {
+      f.wait();
+    }
+  }
+
+  template <typename T>
+  void all_sum_impl(
+      T* buffer,
+      T* data,
+      size_t data_size,
+      int socket_right,
+      int socket_left,
+      int direction) {
+    // Choose which socket we send to and recv from
+    int socket_send = (direction < 0) ? socket_right : socket_left;
+    int socket_recv = (direction < 0) ? socket_left : socket_right;
+
+    // We split the data into `size_` segments of size `segment_size` and each
+    // of these in smaller segments of ALL_SUM_SIZE which we'll call packets.
+    size_t segment_size = ceildiv(data_size, size_);
+    size_t BUFFER_SIZE =
+        std::max(32768UL, std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2));
+    size_t n_packets = ceildiv(segment_size, BUFFER_SIZE);
+
+    // Initial segments
+    int send_segment = rank_;
+    int recv_segment = (rank_ + direction + size_) % size_;
+
+    // Plan the whole reduce in terms of sends and recvs as indices in data.
+    // It makes the actual async send and recv a bit simpler to follow when
+    // there are less offset calculations around.
+    std::vector<std::pair<size_t, size_t>> send_plan;
+    std::vector<std::pair<size_t, size_t>> recv_plan;
+
+    // Two times the same send/recv operations, first scatter reduce and then
+    // gather.
+    for (int k = 0; k < 2; k++) {
+      for (int i = 0; i < size_ - 1; i++) {
+        size_t send_start = send_segment * segment_size;
+        size_t send_stop =
+            std::min((send_segment + 1) * segment_size, data_size);
+        size_t recv_start = recv_segment * segment_size;
+        size_t recv_stop =
+            std::min((recv_segment + 1) * segment_size, data_size);
+
+        for (size_t j = 0; j < n_packets; j++) {
+          send_plan.emplace_back(
+              std::min(send_start + j * BUFFER_SIZE, send_stop),
+              std::min(send_start + (j + 1) * BUFFER_SIZE, send_stop));
+          recv_plan.emplace_back(
+              std::min(recv_start + j * BUFFER_SIZE, recv_stop),
+              std::min(recv_start + (j + 1) * BUFFER_SIZE, recv_stop));
+        }
+
+        send_segment = (send_segment + size_ + direction) % size_;
+        recv_segment = (recv_segment + size_ + direction) % size_;
+      }
+    }
+
+    // Running the plan is fairly simple, we keep a send and a recv in flight
+    // while doing the summation.
+    T* recv_buffers[ALL_SUM_BUFFERS];
+    for (int i = 0; i < ALL_SUM_BUFFERS; i++) {
+      recv_buffers[i] = buffer + i * BUFFER_SIZE;
+    }
+    std::future<void> sends[2], recvs[2];
+    int a = 0;
+    int b = (n_packets > 1) ? 1 : 0;
+    for (int i = 0, j = -b; i < send_plan.size(); j++, i++) {
+      sends[a] = comm_.send(
+          socket_send,
+          data + send_plan[i].first,
+          send_plan[i].second - send_plan[i].first);
+      if (2 * i < send_plan.size()) {
+        recvs[a] = comm_.recv(
+            socket_recv,
+            recv_buffers[i % ALL_SUM_BUFFERS],
+            recv_plan[i].second - recv_plan[i].first);
+      } else {
+        recvs[a] = comm_.recv(
+            socket_recv,
+            data + recv_plan[i].first,
+            recv_plan[i].second - recv_plan[i].first);
+      }
+
+      if (j >= 0) {
+        sends[b].wait();
+        recvs[b].wait();
+        if (2 * j < send_plan.size()) {
+          sum_inplace<T>(
+              recv_buffers[j % ALL_SUM_BUFFERS],
+              data + recv_plan[j].first,
+              recv_plan[j].second - recv_plan[j].first);
+        }
+      }
+
+      std::swap(a, b);
+    }
+    sends[b].wait();
+    recvs[b].wait();
+  }
+
+  void send(const std::vector<int>& sockets, char* data, size_t data_size) {
+    size_t segment_size = ceildiv(data_size, sockets.size());
+    std::vector<std::future<void>> sends;
+    for (int i = 0; i < sockets.size(); i++) {
+      sends.emplace_back(comm_.send(
+          sockets[i],
+          data + i * segment_size,
+          std::min(data_size, (i + 1) * segment_size) - i * segment_size));
+    }
+    for (auto& f : sends) {
+      f.wait();
+    }
+  }
+
+  void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
+    size_t segment_size = ceildiv(data_size, sockets.size());
+    std::vector<std::future<void>> recvs;
+    for (int i = 0; i < sockets.size(); i++) {
+      recvs.emplace_back(comm_.recv(
+          sockets[i],
+          data + i * segment_size,
+          std::min(data_size, (i + 1) * segment_size) - i * segment_size));
+    }
+    for (auto& f : recvs) {
       f.wait();
     }
   }
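To make the plan concrete, a standalone sketch (illustrative parameters, not part of the diff) that reproduces the send side of the planning loops above for a 3-node ring, 12 elements, and one packet per segment:

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  int size = 3, rank = 0, direction = -1; // hypothetical ring
  size_t data_size = 12;
  size_t segment_size = (data_size + size - 1) / size; // ceildiv -> 4
  int send_segment = rank;

  std::vector<std::pair<size_t, size_t>> send_plan;
  for (int k = 0; k < 2; k++) { // pass 0: scatter reduce, pass 1: gather
    for (int i = 0; i < size - 1; i++) {
      size_t start = send_segment * segment_size;
      size_t stop = std::min((send_segment + 1) * segment_size, data_size);
      send_plan.emplace_back(start, stop);
      send_segment = (send_segment + size + direction) % size;
    }
  }
  for (auto& [start, stop] : send_plan) {
    std::printf("send [%zu, %zu)\n", start, stop);
  }
  // Rank 0 sends segments 0, 2, 1, 0 in order: [0,4) [8,12) [4,8) [0,4).
}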
@@ -796,9 +834,12 @@ class RingGroup : public GroupImpl {
   bool verbose_;

   ThreadPool pool_;
+  CommunicationThreads comm_;

-  std::vector<int> send_sockets_;
-  std::vector<int> recv_sockets_;
+  std::vector<int> sockets_right_;
+  std::vector<int> sockets_left_;
+
+  std::vector<char> buffers_;
 };

 bool is_available() {
@@ -56,6 +56,45 @@ class TestRingDistributed(mlx_tests.MLXTestCase):
         maxrelerror = ((y - z).abs() / z.abs()).max()
         self.assertLessEqual(maxrelerror, rtol)

+    def test_send_recv(self):
+        world = mx.distributed.init()
+        dtypes = [
+            mx.int8,
+            mx.uint8,
+            mx.int16,
+            mx.uint16,
+            mx.int32,
+            mx.uint32,
+            mx.float32,
+            mx.float16,
+            mx.bfloat16,
+            mx.complex64,
+        ]
+        sizes = [
+            (7,),
+            (10,),
+            (1024,),
+            (1024, 1024),
+        ]
+        key = mx.random.key(0)
+        right = (world.rank() + 1) % world.size()
+        left = (world.rank() + world.size() - 1) % world.size()
+
+        for dt in dtypes:
+            for sh in sizes:
+                x = (
+                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
+                ).astype(dt)
+                if world.rank() % 2 == 0:
+                    y = mx.distributed.send(x[world.rank()], right)
+                    z = mx.distributed.recv_like(y, left)
+                    mx.eval(y, z)
+                else:
+                    z = mx.distributed.recv_like(x[world.rank()], left)
+                    y = mx.distributed.send(x[world.rank()], right)
+                    mx.eval(z, y)
+
+                self.assertTrue(mx.all(y == x[world.rank()]))
+                self.assertTrue(mx.all(z == x[left]))
+

 if __name__ == "__main__":
     unittest.main()