mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
WIP (distributed)
This commit is contained in:
@@ -27,7 +27,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
|
||||
}
|
||||
|
||||
std::vector<array> AllReduce::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& /* primals */,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>&) {
|
||||
switch (reduce_type_) {
|
||||
@@ -44,10 +44,10 @@ std::vector<array> AllReduce::jvp(
|
||||
}
|
||||
|
||||
std::vector<array> AllReduce::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& /* primals */,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>& outputs) {
|
||||
const std::vector<array>& /* outputs */) {
|
||||
return cotangents;
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
|
||||
}
|
||||
|
||||
std::vector<array> AllGather::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& /* primals */,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>&) {
|
||||
return {all_gather(tangents[0], group(), stream())};
|
||||
|
||||
@@ -90,8 +90,8 @@
|
||||
|
||||
namespace mlx::core::distributed::ring {
|
||||
|
||||
constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
|
||||
constexpr const size_t ALL_SUM_BUFFERS = 2;
|
||||
constexpr const int64_t ALL_SUM_SIZE = 8 * 1024 * 1024;
|
||||
constexpr const int64_t ALL_SUM_BUFFERS = 2;
|
||||
constexpr const int CONN_ATTEMPTS = 5;
|
||||
constexpr const int CONN_WAIT = 1000;
|
||||
|
||||
@@ -141,27 +141,27 @@ class SocketThread {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::future<void> send(const T* buffer, size_t size) {
|
||||
std::future<void> send(const T* buffer, int64_t size) {
|
||||
return send_impl(reinterpret_cast<const char*>(buffer), size * sizeof(T));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::future<void> recv(T* buffer, size_t size) {
|
||||
std::future<void> recv(T* buffer, int64_t size) {
|
||||
return recv_impl(reinterpret_cast<char*>(buffer), size * sizeof(T));
|
||||
}
|
||||
|
||||
private:
|
||||
struct SocketTask {
|
||||
SocketTask(void* b, size_t s, std::promise<void>&& p)
|
||||
SocketTask(void* b, int64_t s, std::promise<void>&& p)
|
||||
: buffer(b), size(s), promise(std::move(p)) {}
|
||||
SocketTask(SocketTask&& t)
|
||||
: buffer(t.buffer), size(t.size), promise(std::move(t.promise)) {}
|
||||
void* buffer;
|
||||
size_t size;
|
||||
int64_t size;
|
||||
std::promise<void> promise;
|
||||
};
|
||||
|
||||
std::future<void> send_impl(const char* buffer, size_t size) {
|
||||
std::future<void> send_impl(const char* buffer, int64_t size) {
|
||||
std::promise<void> send_completed_promise;
|
||||
auto send_completed_future = send_completed_promise.get_future();
|
||||
if (size == 0) {
|
||||
@@ -178,7 +178,7 @@ class SocketThread {
|
||||
return send_completed_future;
|
||||
}
|
||||
|
||||
std::future<void> recv_impl(char* buffer, size_t size) {
|
||||
std::future<void> recv_impl(char* buffer, int64_t size) {
|
||||
std::promise<void> recv_completed_promise;
|
||||
auto recv_completed_future = recv_completed_promise.get_future();
|
||||
if (size == 0) {
|
||||
@@ -232,7 +232,7 @@ class SocketThread {
|
||||
|
||||
if (!recvs_.empty()) {
|
||||
auto& task = recvs_.front();
|
||||
ssize_t r = ::recv(fd_, task.buffer, task.size, 0);
|
||||
int64_t r = ::recv(fd_, task.buffer, task.size, 0);
|
||||
if (r > 0) {
|
||||
task.buffer = static_cast<char*>(task.buffer) + r;
|
||||
task.size -= r;
|
||||
@@ -246,7 +246,7 @@ class SocketThread {
|
||||
}
|
||||
if (!sends_.empty()) {
|
||||
auto& task = sends_.front();
|
||||
ssize_t r = ::send(fd_, task.buffer, task.size, 0);
|
||||
int64_t r = ::send(fd_, task.buffer, task.size, 0);
|
||||
if (r > 0) {
|
||||
task.buffer = static_cast<char*>(task.buffer) + r;
|
||||
task.size -= r;
|
||||
@@ -283,12 +283,12 @@ class CommunicationThreads {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::future<void> send(int socket, T* buffer, size_t size) {
|
||||
std::future<void> send(int socket, T* buffer, int64_t size) {
|
||||
return threads_.at(socket).send<T>(buffer, size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::future<void> recv(int socket, T* buffer, size_t size) {
|
||||
std::future<void> recv(int socket, T* buffer, int64_t size) {
|
||||
return threads_.at(socket).recv<T>(buffer, size);
|
||||
}
|
||||
|
||||
@@ -505,7 +505,7 @@ std::vector<int> make_connections(
|
||||
}
|
||||
template <typename T>
|
||||
struct SumOp {
|
||||
void operator()(const T* input, T* output, size_t N) {
|
||||
void operator()(const T* input, T* output, int64_t N) {
|
||||
while (N-- > 0) {
|
||||
*output += *input;
|
||||
input++;
|
||||
@@ -516,7 +516,7 @@ struct SumOp {
|
||||
|
||||
template <typename T>
|
||||
struct MaxOp {
|
||||
void operator()(const T* input, T* output, size_t N) {
|
||||
void operator()(const T* input, T* output, int64_t N) {
|
||||
while (N-- > 0) {
|
||||
*output = std::max(*output, *input);
|
||||
input++;
|
||||
@@ -527,7 +527,7 @@ struct MaxOp {
|
||||
|
||||
template <typename T>
|
||||
struct MinOp {
|
||||
void operator()(const T* input, T* output, size_t N) {
|
||||
void operator()(const T* input, T* output, int64_t N) {
|
||||
while (N-- > 0) {
|
||||
*output = std::min(*output, *input);
|
||||
input++;
|
||||
@@ -542,7 +542,7 @@ class RingGroup : public GroupImpl {
|
||||
public:
|
||||
RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
|
||||
: rank_(rank), verbose_(verbose), pool_(0) {
|
||||
if (rank_ > 0 && rank_ >= nodes.size()) {
|
||||
if (rank_ > 0 && rank_ >= std::ssize(nodes)) {
|
||||
throw std::runtime_error(
|
||||
"[ring] Rank cannot be larger than the size of the group");
|
||||
}
|
||||
@@ -589,7 +589,7 @@ class RingGroup : public GroupImpl {
|
||||
|
||||
// Configure all sockets to use TCP no delay.
|
||||
int one = 1;
|
||||
for (int i = 0; i < sockets_right_.size(); i++) {
|
||||
for (int64_t i = 0; i < std::ssize(sockets_right_); i++) {
|
||||
setsockopt(sockets_right_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
|
||||
setsockopt(sockets_left_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
|
||||
}
|
||||
@@ -646,7 +646,8 @@ class RingGroup : public GroupImpl {
|
||||
output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
|
||||
}
|
||||
|
||||
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||
std::shared_ptr<GroupImpl> split(int /* color */, int /* key */ = -1)
|
||||
override {
|
||||
throw std::runtime_error("[ring] Group split not supported.");
|
||||
}
|
||||
|
||||
@@ -658,15 +659,15 @@ class RingGroup : public GroupImpl {
|
||||
nbytes = input.nbytes(),
|
||||
output_ptr = output.data<char>(),
|
||||
this]() {
|
||||
constexpr size_t min_send_size = 262144;
|
||||
size_t n_gathers = std::max(
|
||||
std::min(
|
||||
constexpr int64_t min_send_size = 262144;
|
||||
int64_t n_gathers = std::max<int64_t>(
|
||||
std::min<int64_t>(
|
||||
sockets_right_.size() + sockets_left_.size(),
|
||||
nbytes / min_send_size),
|
||||
size_t(1));
|
||||
size_t bytes_per_gather = ceildiv(nbytes, n_gathers);
|
||||
1);
|
||||
int64_t bytes_per_gather = ceildiv(nbytes, n_gathers);
|
||||
std::vector<std::future<void>> all_gathers;
|
||||
for (int i = 0; i < n_gathers; i++) {
|
||||
for (int64_t i = 0; i < n_gathers; i++) {
|
||||
auto offset = i * bytes_per_gather;
|
||||
all_gathers.emplace_back(pool_.enqueue(std::bind(
|
||||
&RingGroup::all_gather_impl,
|
||||
@@ -742,10 +743,14 @@ class RingGroup : public GroupImpl {
|
||||
auto out_ptr = output.data<char>();
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(output);
|
||||
encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() {
|
||||
encoder.dispatch([in_ptr,
|
||||
out_ptr,
|
||||
size = static_cast<int64_t>(input.size()),
|
||||
this,
|
||||
reduce_op]() {
|
||||
// If the input data cannot be split into size_ segments then copy it and
|
||||
// all reduce a local buffer prefilled with 0s.
|
||||
size_t nbytes = size * sizeof(T);
|
||||
int64_t nbytes = size * sizeof(T);
|
||||
if (size < size_) {
|
||||
// TODO: Maybe allocate dynamically so we don't have the constraint
|
||||
// below?
|
||||
@@ -778,16 +783,16 @@ class RingGroup : public GroupImpl {
|
||||
|
||||
// Split the all reduces so that each member has at least 1 buffer to
|
||||
// send/recv per segment.
|
||||
constexpr size_t min_send_size = 262144;
|
||||
size_t n_reduces = std::max(
|
||||
std::min(
|
||||
constexpr int64_t min_send_size = 262144;
|
||||
int64_t n_reduces = std::max<int64_t>(
|
||||
std::min<int64_t>(
|
||||
sockets_right_.size() + sockets_left_.size(),
|
||||
nbytes / (size_ * min_send_size)),
|
||||
size_t(1));
|
||||
size_t step = ceildiv(size, n_reduces);
|
||||
1);
|
||||
int64_t step = ceildiv(size, n_reduces);
|
||||
std::vector<std::future<void>> all_sums;
|
||||
|
||||
for (int i = 0; i < n_reduces; i++) {
|
||||
for (int64_t i = 0; i < n_reduces; i++) {
|
||||
all_sums.emplace_back(pool_.enqueue(std::bind(
|
||||
&RingGroup::all_reduce_impl<T, ReduceOp>,
|
||||
this,
|
||||
@@ -810,7 +815,7 @@ class RingGroup : public GroupImpl {
|
||||
void all_reduce_impl(
|
||||
T* buffer,
|
||||
T* data,
|
||||
size_t data_size,
|
||||
int64_t data_size,
|
||||
int socket_right,
|
||||
int socket_left,
|
||||
int direction,
|
||||
@@ -821,10 +826,10 @@ class RingGroup : public GroupImpl {
|
||||
|
||||
// We split the data into `size_` segments of size `segment_size` and each
|
||||
// of these into smaller segments of ALL_SUM_SIZE which we'll call packets.
|
||||
size_t segment_size = ceildiv(data_size, size_);
|
||||
size_t BUFFER_SIZE = std::max(
|
||||
size_t(32768), std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2));
|
||||
size_t n_packets = ceildiv(segment_size, BUFFER_SIZE);
|
||||
int64_t segment_size = ceildiv(data_size, size_);
|
||||
int64_t BUFFER_SIZE = std::max<int64_t>(
|
||||
32768, std::min<int64_t>(ALL_SUM_SIZE / sizeof(T), segment_size / 2));
|
||||
int64_t n_packets = ceildiv(segment_size, BUFFER_SIZE);
|
||||
|
||||
// Initial segments
|
||||
int send_segment = rank_;
|
||||
@@ -833,21 +838,21 @@ class RingGroup : public GroupImpl {
|
||||
// Plan the whole reduce in terms of sends and recvs as indices in data.
|
||||
// It makes the actual async send and recv a bit simpler to follow when
|
||||
// there are fewer offset calculations around.
|
||||
std::vector<std::pair<size_t, size_t>> send_plan;
|
||||
std::vector<std::pair<size_t, size_t>> recv_plan;
|
||||
std::vector<std::pair<int64_t, int64_t>> send_plan;
|
||||
std::vector<std::pair<int64_t, int64_t>> recv_plan;
|
||||
|
||||
// Two times the same send/recv operations, first scatter reduce and then
|
||||
// gather.
|
||||
for (int k = 0; k < 2; k++) {
|
||||
for (int i = 0; i < size_ - 1; i++) {
|
||||
size_t send_start = send_segment * segment_size;
|
||||
size_t send_stop =
|
||||
int64_t send_start = send_segment * segment_size;
|
||||
int64_t send_stop =
|
||||
std::min((send_segment + 1) * segment_size, data_size);
|
||||
size_t recv_start = recv_segment * segment_size;
|
||||
size_t recv_stop =
|
||||
int64_t recv_start = recv_segment * segment_size;
|
||||
int64_t recv_stop =
|
||||
std::min((recv_segment + 1) * segment_size, data_size);
|
||||
|
||||
for (size_t j = 0; j < n_packets; j++) {
|
||||
for (int64_t j = 0; j < n_packets; j++) {
|
||||
send_plan.emplace_back(
|
||||
std::min(send_start + j * BUFFER_SIZE, send_stop),
|
||||
std::min(send_start + (j + 1) * BUFFER_SIZE, send_stop));
|
||||
@@ -864,18 +869,18 @@ class RingGroup : public GroupImpl {
|
||||
// Running the plan is fairly simple, we keep a send and a recv in flight
|
||||
// while doing the summation.
|
||||
T* recv_buffers[ALL_SUM_BUFFERS];
|
||||
for (int i = 0; i < ALL_SUM_BUFFERS; i++) {
|
||||
for (int64_t i = 0; i < ALL_SUM_BUFFERS; i++) {
|
||||
recv_buffers[i] = buffer + i * BUFFER_SIZE;
|
||||
}
|
||||
std::future<void> sends[2], recvs[2];
|
||||
int a = 0;
|
||||
int b = (n_packets > 1) ? 1 : 0;
|
||||
for (int i = 0, j = -b; i < send_plan.size(); j++, i++) {
|
||||
for (int i = 0, j = -b; i < std::ssize(send_plan); j++, i++) {
|
||||
sends[a] = comm_.send(
|
||||
socket_send,
|
||||
data + send_plan[i].first,
|
||||
send_plan[i].second - send_plan[i].first);
|
||||
if (2 * i < send_plan.size()) {
|
||||
if (2 * i < std::ssize(send_plan)) {
|
||||
recvs[a] = comm_.recv(
|
||||
socket_recv,
|
||||
recv_buffers[i % ALL_SUM_BUFFERS],
|
||||
@@ -890,7 +895,7 @@ class RingGroup : public GroupImpl {
|
||||
if (j >= 0) {
|
||||
sends[b].wait();
|
||||
recvs[b].wait();
|
||||
if (2 * j < send_plan.size()) {
|
||||
if (2 * j < std::ssize(send_plan)) {
|
||||
reduce_op(
|
||||
recv_buffers[j % ALL_SUM_BUFFERS],
|
||||
data + recv_plan[j].first,
|
||||
@@ -907,8 +912,8 @@ class RingGroup : public GroupImpl {
|
||||
void all_gather_impl(
|
||||
const char* input,
|
||||
char* output,
|
||||
size_t input_size,
|
||||
size_t data_size,
|
||||
int64_t input_size,
|
||||
int64_t data_size,
|
||||
int socket_right,
|
||||
int socket_left,
|
||||
int direction) {
|
||||
@@ -941,11 +946,11 @@ class RingGroup : public GroupImpl {
|
||||
}
|
||||
|
||||
void
|
||||
send(const std::vector<int>& sockets, const char* data, size_t data_size) {
|
||||
size_t segment_size =
|
||||
std::max(size_t(1024), ceildiv(data_size, sockets.size()));
|
||||
send(const std::vector<int>& sockets, const char* data, int64_t data_size) {
|
||||
int64_t segment_size =
|
||||
std::max<int64_t>(1024, ceildiv(data_size, std::ssize(sockets)));
|
||||
std::vector<std::future<void>> sends;
|
||||
for (int i = 0; i < sockets.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(sockets); i++) {
|
||||
if (i * segment_size >= data_size) {
|
||||
break;
|
||||
}
|
||||
@@ -959,11 +964,11 @@ class RingGroup : public GroupImpl {
|
||||
}
|
||||
}
|
||||
|
||||
void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
|
||||
size_t segment_size =
|
||||
std::max(size_t(1024), ceildiv(data_size, sockets.size()));
|
||||
void recv(const std::vector<int>& sockets, char* data, int64_t data_size) {
|
||||
int64_t segment_size =
|
||||
std::max<int64_t>(1024, ceildiv(data_size, std::ssize(sockets)));
|
||||
std::vector<std::future<void>> recvs;
|
||||
for (int i = 0; i < sockets.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(sockets); i++) {
|
||||
if (i * segment_size >= data_size) {
|
||||
break;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user