WIP (distributed)

Ronan Collobert
2025-10-30 16:25:11 -07:00
parent 45a8b226af
commit a1212b4e44
2 changed files with 67 additions and 62 deletions

View File

@@ -27,7 +27,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
 }
 
 std::vector<array> AllReduce::jvp(
-    const std::vector<array>& primals,
+    const std::vector<array>& /* primals */,
     const std::vector<array>& tangents,
     const std::vector<int>&) {
   switch (reduce_type_) {
@@ -44,10 +44,10 @@ std::vector<array> AllReduce::jvp(
 }
 
 std::vector<array> AllReduce::vjp(
-    const std::vector<array>& primals,
+    const std::vector<array>& /* primals */,
     const std::vector<array>& cotangents,
     const std::vector<int>&,
-    const std::vector<array>& outputs) {
+    const std::vector<array>& /* outputs */) {
   return cotangents;
 }
@@ -58,7 +58,7 @@ std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
 }
 
 std::vector<array> AllGather::jvp(
-    const std::vector<array>& primals,
+    const std::vector<array>& /* primals */,
     const std::vector<array>& tangents,
     const std::vector<int>&) {
   return {all_gather(tangents[0], group(), stream())};

View File

@@ -90,8 +90,8 @@
 namespace mlx::core::distributed::ring {
 
-constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
-constexpr const size_t ALL_SUM_BUFFERS = 2;
+constexpr const int64_t ALL_SUM_SIZE = 8 * 1024 * 1024;
+constexpr const int64_t ALL_SUM_BUFFERS = 2;
 constexpr const int CONN_ATTEMPTS = 5;
 constexpr const int CONN_WAIT = 1000;
@@ -141,27 +141,27 @@ class SocketThread {
   }
 
   template <typename T>
-  std::future<void> send(const T* buffer, size_t size) {
+  std::future<void> send(const T* buffer, int64_t size) {
     return send_impl(reinterpret_cast<const char*>(buffer), size * sizeof(T));
   }
 
   template <typename T>
-  std::future<void> recv(T* buffer, size_t size) {
+  std::future<void> recv(T* buffer, int64_t size) {
     return recv_impl(reinterpret_cast<char*>(buffer), size * sizeof(T));
   }
 
  private:
   struct SocketTask {
-    SocketTask(void* b, size_t s, std::promise<void>&& p)
+    SocketTask(void* b, int64_t s, std::promise<void>&& p)
         : buffer(b), size(s), promise(std::move(p)) {}
     SocketTask(SocketTask&& t)
         : buffer(t.buffer), size(t.size), promise(std::move(t.promise)) {}
     void* buffer;
-    size_t size;
+    int64_t size;
     std::promise<void> promise;
   };
 
-  std::future<void> send_impl(const char* buffer, size_t size) {
+  std::future<void> send_impl(const char* buffer, int64_t size) {
     std::promise<void> send_completed_promise;
     auto send_completed_future = send_completed_promise.get_future();
     if (size == 0) {
@@ -178,7 +178,7 @@ class SocketThread {
     return send_completed_future;
   }
 
-  std::future<void> recv_impl(char* buffer, size_t size) {
+  std::future<void> recv_impl(char* buffer, int64_t size) {
     std::promise<void> recv_completed_promise;
     auto recv_completed_future = recv_completed_promise.get_future();
     if (size == 0) {
@@ -232,7 +232,7 @@ class SocketThread {
       if (!recvs_.empty()) {
         auto& task = recvs_.front();
-        ssize_t r = ::recv(fd_, task.buffer, task.size, 0);
+        int64_t r = ::recv(fd_, task.buffer, task.size, 0);
         if (r > 0) {
           task.buffer = static_cast<char*>(task.buffer) + r;
           task.size -= r;
@@ -246,7 +246,7 @@ class SocketThread {
       }
       if (!sends_.empty()) {
         auto& task = sends_.front();
-        ssize_t r = ::send(fd_, task.buffer, task.size, 0);
+        int64_t r = ::send(fd_, task.buffer, task.size, 0);
         if (r > 0) {
           task.buffer = static_cast<char*>(task.buffer) + r;
           task.size -= r;
@@ -283,12 +283,12 @@ class CommunicationThreads {
   }
 
   template <typename T>
-  std::future<void> send(int socket, T* buffer, size_t size) {
+  std::future<void> send(int socket, T* buffer, int64_t size) {
     return threads_.at(socket).send<T>(buffer, size);
   }
 
   template <typename T>
-  std::future<void> recv(int socket, T* buffer, size_t size) {
+  std::future<void> recv(int socket, T* buffer, int64_t size) {
     return threads_.at(socket).recv<T>(buffer, size);
   }
@@ -505,7 +505,7 @@ std::vector<int> make_connections(
 }
 
 template <typename T>
 struct SumOp {
-  void operator()(const T* input, T* output, size_t N) {
+  void operator()(const T* input, T* output, int64_t N) {
     while (N-- > 0) {
       *output += *input;
       input++;
@@ -516,7 +516,7 @@ struct SumOp {
 template <typename T>
 struct MaxOp {
-  void operator()(const T* input, T* output, size_t N) {
+  void operator()(const T* input, T* output, int64_t N) {
     while (N-- > 0) {
       *output = std::max(*output, *input);
       input++;
@@ -527,7 +527,7 @@ struct MaxOp {
 template <typename T>
 struct MinOp {
-  void operator()(const T* input, T* output, size_t N) {
+  void operator()(const T* input, T* output, int64_t N) {
     while (N-- > 0) {
       *output = std::min(*output, *input);
       input++;
@@ -542,7 +542,7 @@ class RingGroup : public GroupImpl {
  public:
   RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
       : rank_(rank), verbose_(verbose), pool_(0) {
-    if (rank_ > 0 && rank_ >= nodes.size()) {
+    if (rank_ > 0 && rank_ >= std::ssize(nodes)) {
       throw std::runtime_error(
           "[ring] Rank cannot be larger than the size of the group");
     }
@@ -589,7 +589,7 @@ class RingGroup : public GroupImpl {
     // Configure all sockets to use TCP no delay.
     int one = 1;
-    for (int i = 0; i < sockets_right_.size(); i++) {
+    for (int64_t i = 0; i < std::ssize(sockets_right_); i++) {
       setsockopt(sockets_right_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
       setsockopt(sockets_left_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
     }
@@ -646,7 +646,8 @@ class RingGroup : public GroupImpl {
         output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
   }
 
-  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
+  std::shared_ptr<GroupImpl> split(int /* color */, int /* key */ = -1)
+      override {
     throw std::runtime_error("[ring] Group split not supported.");
   }
@@ -658,15 +659,15 @@ class RingGroup : public GroupImpl {
                       nbytes = input.nbytes(),
                       output_ptr = output.data<char>(),
                       this]() {
-      constexpr size_t min_send_size = 262144;
-      size_t n_gathers = std::max(
-          std::min(
+      constexpr int64_t min_send_size = 262144;
+      int64_t n_gathers = std::max<int64_t>(
+          std::min<int64_t>(
               sockets_right_.size() + sockets_left_.size(),
               nbytes / min_send_size),
-          size_t(1));
-      size_t bytes_per_gather = ceildiv(nbytes, n_gathers);
+          1);
+      int64_t bytes_per_gather = ceildiv(nbytes, n_gathers);
       std::vector<std::future<void>> all_gathers;
-      for (int i = 0; i < n_gathers; i++) {
+      for (int64_t i = 0; i < n_gathers; i++) {
         auto offset = i * bytes_per_gather;
         all_gathers.emplace_back(pool_.enqueue(std::bind(
             &RingGroup::all_gather_impl,
@@ -742,10 +743,14 @@ class RingGroup : public GroupImpl {
     auto out_ptr = output.data<char>();
     auto& encoder = cpu::get_command_encoder(stream);
     encoder.set_output_array(output);
-    encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() {
+    encoder.dispatch([in_ptr,
+                      out_ptr,
+                      size = static_cast<int64_t>(input.size()),
+                      this,
+                      reduce_op]() {
       // If the input data cannot be split into size_ segments then copy it and
       // all reduce a local buffer prefilled with 0s.
-      size_t nbytes = size * sizeof(T);
+      int64_t nbytes = size * sizeof(T);
       if (size < size_) {
         // TODO: Maybe allocate dynamically so we don't have the constraint
         // below?
@@ -778,16 +783,16 @@ class RingGroup : public GroupImpl {
       // Split the all reduces so that each member has at least 1 buffer to
       // send/recv per segment.
-      constexpr size_t min_send_size = 262144;
-      size_t n_reduces = std::max(
-          std::min(
+      constexpr int64_t min_send_size = 262144;
+      int64_t n_reduces = std::max<int64_t>(
+          std::min<int64_t>(
               sockets_right_.size() + sockets_left_.size(),
               nbytes / (size_ * min_send_size)),
-          size_t(1));
-      size_t step = ceildiv(size, n_reduces);
+          1);
+      int64_t step = ceildiv(size, n_reduces);
       std::vector<std::future<void>> all_sums;
-      for (int i = 0; i < n_reduces; i++) {
+      for (int64_t i = 0; i < n_reduces; i++) {
         all_sums.emplace_back(pool_.enqueue(std::bind(
             &RingGroup::all_reduce_impl<T, ReduceOp>,
             this,
@@ -810,7 +815,7 @@ class RingGroup : public GroupImpl {
   void all_reduce_impl(
       T* buffer,
       T* data,
-      size_t data_size,
+      int64_t data_size,
       int socket_right,
       int socket_left,
       int direction,
@@ -821,10 +826,10 @@ class RingGroup : public GroupImpl {
     // We split the data into `size_` segments of size `segment_size` and each
     // of these in smaller segments of ALL_SUM_SIZE which we'll call packets.
-    size_t segment_size = ceildiv(data_size, size_);
-    size_t BUFFER_SIZE = std::max(
-        size_t(32768), std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2));
-    size_t n_packets = ceildiv(segment_size, BUFFER_SIZE);
+    int64_t segment_size = ceildiv(data_size, size_);
+    int64_t BUFFER_SIZE = std::max<int64_t>(
+        32768, std::min<int64_t>(ALL_SUM_SIZE / sizeof(T), segment_size / 2));
+    int64_t n_packets = ceildiv(segment_size, BUFFER_SIZE);
 
     // Initial segments
     int send_segment = rank_;
@@ -833,21 +838,21 @@ class RingGroup : public GroupImpl {
     // Plan the whole reduce in terms of sends and recvs as indices in data.
     // It makes the actual async send and recv a bit simpler to follow when
     // there are less offset calculations around.
-    std::vector<std::pair<size_t, size_t>> send_plan;
-    std::vector<std::pair<size_t, size_t>> recv_plan;
+    std::vector<std::pair<int64_t, int64_t>> send_plan;
+    std::vector<std::pair<int64_t, int64_t>> recv_plan;
 
     // Two times the same send/recv operations, first scatter reduce and then
     // gather.
     for (int k = 0; k < 2; k++) {
       for (int i = 0; i < size_ - 1; i++) {
-        size_t send_start = send_segment * segment_size;
-        size_t send_stop =
+        int64_t send_start = send_segment * segment_size;
+        int64_t send_stop =
             std::min((send_segment + 1) * segment_size, data_size);
-        size_t recv_start = recv_segment * segment_size;
-        size_t recv_stop =
+        int64_t recv_start = recv_segment * segment_size;
+        int64_t recv_stop =
             std::min((recv_segment + 1) * segment_size, data_size);
-        for (size_t j = 0; j < n_packets; j++) {
+        for (int64_t j = 0; j < n_packets; j++) {
           send_plan.emplace_back(
               std::min(send_start + j * BUFFER_SIZE, send_stop),
               std::min(send_start + (j + 1) * BUFFER_SIZE, send_stop));
@@ -864,18 +869,18 @@ class RingGroup : public GroupImpl {
     // Running the plan is fairly simple, we keep a send and a recv in flight
     // while doing the summation.
     T* recv_buffers[ALL_SUM_BUFFERS];
-    for (int i = 0; i < ALL_SUM_BUFFERS; i++) {
+    for (int64_t i = 0; i < ALL_SUM_BUFFERS; i++) {
       recv_buffers[i] = buffer + i * BUFFER_SIZE;
     }
     std::future<void> sends[2], recvs[2];
     int a = 0;
     int b = (n_packets > 1) ? 1 : 0;
-    for (int i = 0, j = -b; i < send_plan.size(); j++, i++) {
+    for (int i = 0, j = -b; i < std::ssize(send_plan); j++, i++) {
       sends[a] = comm_.send(
           socket_send,
           data + send_plan[i].first,
           send_plan[i].second - send_plan[i].first);
-      if (2 * i < send_plan.size()) {
+      if (2 * i < std::ssize(send_plan)) {
         recvs[a] = comm_.recv(
             socket_recv,
             recv_buffers[i % ALL_SUM_BUFFERS],
@@ -890,7 +895,7 @@ class RingGroup : public GroupImpl {
       if (j >= 0) {
         sends[b].wait();
         recvs[b].wait();
-        if (2 * j < send_plan.size()) {
+        if (2 * j < std::ssize(send_plan)) {
           reduce_op(
               recv_buffers[j % ALL_SUM_BUFFERS],
               data + recv_plan[j].first,
@@ -907,8 +912,8 @@ class RingGroup : public GroupImpl {
   void all_gather_impl(
       const char* input,
       char* output,
-      size_t input_size,
-      size_t data_size,
+      int64_t input_size,
+      int64_t data_size,
       int socket_right,
       int socket_left,
       int direction) {
@@ -941,11 +946,11 @@ class RingGroup : public GroupImpl {
   }
 
   void
-  send(const std::vector<int>& sockets, const char* data, size_t data_size) {
-    size_t segment_size =
-        std::max(size_t(1024), ceildiv(data_size, sockets.size()));
+  send(const std::vector<int>& sockets, const char* data, int64_t data_size) {
+    int64_t segment_size =
+        std::max<int64_t>(1024, ceildiv(data_size, std::ssize(sockets)));
     std::vector<std::future<void>> sends;
-    for (int i = 0; i < sockets.size(); i++) {
+    for (int i = 0; i < std::ssize(sockets); i++) {
       if (i * segment_size >= data_size) {
         break;
       }
@@ -959,11 +964,11 @@ class RingGroup : public GroupImpl {
     }
   }
 
-  void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
-    size_t segment_size =
-        std::max(size_t(1024), ceildiv(data_size, sockets.size()));
+  void recv(const std::vector<int>& sockets, char* data, int64_t data_size) {
+    int64_t segment_size =
+        std::max<int64_t>(1024, ceildiv(data_size, std::ssize(sockets)));
     std::vector<std::future<void>> recvs;
-    for (int i = 0; i < sockets.size(); i++) {
+    for (int i = 0; i < std::ssize(sockets); i++) {
       if (i * segment_size >= data_size) {
         break;
       }
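
The recurring change across both files is moving byte and element counts from size_t/ssize_t to int64_t and comparing against container sizes through std::ssize. Below is a minimal standalone sketch of that pattern, not part of the commit: it assumes C++20, ceildiv is a local stand-in for the helper the diff calls, and the variable names are illustrative only.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

// Local stand-in for the ceildiv helper referenced in the diff.
int64_t ceildiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  std::vector<int> sockets_right(2), sockets_left(2);
  constexpr int64_t min_send_size = 262144;
  int64_t nbytes = 100000; // deliberately smaller than min_send_size

  // Same clamping shape as the updated all_gather path: at least one
  // transfer, at most one per socket, scaled by how many min_send_size
  // chunks fit in nbytes. Using int64_t and std::ssize keeps every
  // operand signed, so the min/max comparisons never mix signedness.
  int64_t n_gathers = std::max<int64_t>(
      std::min<int64_t>(
          std::ssize(sockets_right) + std::ssize(sockets_left),
          nbytes / min_send_size),
      1);
  int64_t bytes_per_gather = ceildiv(nbytes, n_gathers);
  std::cout << n_gathers << " gather(s) of at most " << bytes_per_gather
            << " bytes\n";
  return 0;
}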