Mirror of https://github.com/ml-explore/mlx.git
WIP (distributed)
@@ -27,7 +27,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
 }
 
 std::vector<array> AllReduce::jvp(
-    const std::vector<array>& primals,
+    const std::vector<array>& /* primals */,
     const std::vector<array>& tangents,
     const std::vector<int>&) {
   switch (reduce_type_) {
@@ -44,10 +44,10 @@ std::vector<array> AllReduce::jvp(
 }
 
 std::vector<array> AllReduce::vjp(
-    const std::vector<array>& primals,
+    const std::vector<array>& /* primals */,
     const std::vector<array>& cotangents,
     const std::vector<int>&,
-    const std::vector<array>& outputs) {
+    const std::vector<array>& /* outputs */) {
   return cotangents;
 }
 
@@ -58,7 +58,7 @@ std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
 }
 
 std::vector<array> AllGather::jvp(
-    const std::vector<array>& primals,
+    const std::vector<array>& /* primals */,
     const std::vector<array>& tangents,
     const std::vector<int>&) {
   return {all_gather(tangents[0], group(), stream())};
@@ -90,8 +90,8 @@
 
 namespace mlx::core::distributed::ring {
 
-constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
-constexpr const size_t ALL_SUM_BUFFERS = 2;
+constexpr const int64_t ALL_SUM_SIZE = 8 * 1024 * 1024;
+constexpr const int64_t ALL_SUM_BUFFERS = 2;
 constexpr const int CONN_ATTEMPTS = 5;
 constexpr const int CONN_WAIT = 1000;
 
@@ -141,27 +141,27 @@ class SocketThread {
   }
 
   template <typename T>
-  std::future<void> send(const T* buffer, size_t size) {
+  std::future<void> send(const T* buffer, int64_t size) {
     return send_impl(reinterpret_cast<const char*>(buffer), size * sizeof(T));
   }
 
   template <typename T>
-  std::future<void> recv(T* buffer, size_t size) {
+  std::future<void> recv(T* buffer, int64_t size) {
     return recv_impl(reinterpret_cast<char*>(buffer), size * sizeof(T));
   }
 
  private:
   struct SocketTask {
-    SocketTask(void* b, size_t s, std::promise<void>&& p)
+    SocketTask(void* b, int64_t s, std::promise<void>&& p)
         : buffer(b), size(s), promise(std::move(p)) {}
     SocketTask(SocketTask&& t)
         : buffer(t.buffer), size(t.size), promise(std::move(t.promise)) {}
     void* buffer;
-    size_t size;
+    int64_t size;
     std::promise<void> promise;
   };
 
-  std::future<void> send_impl(const char* buffer, size_t size) {
+  std::future<void> send_impl(const char* buffer, int64_t size) {
     std::promise<void> send_completed_promise;
     auto send_completed_future = send_completed_promise.get_future();
     if (size == 0) {
@@ -178,7 +178,7 @@ class SocketThread {
     return send_completed_future;
   }
 
-  std::future<void> recv_impl(char* buffer, size_t size) {
+  std::future<void> recv_impl(char* buffer, int64_t size) {
     std::promise<void> recv_completed_promise;
     auto recv_completed_future = recv_completed_promise.get_future();
     if (size == 0) {
@@ -232,7 +232,7 @@ class SocketThread {
 
       if (!recvs_.empty()) {
         auto& task = recvs_.front();
-        ssize_t r = ::recv(fd_, task.buffer, task.size, 0);
+        int64_t r = ::recv(fd_, task.buffer, task.size, 0);
         if (r > 0) {
           task.buffer = static_cast<char*>(task.buffer) + r;
           task.size -= r;
@@ -246,7 +246,7 @@ class SocketThread {
       }
       if (!sends_.empty()) {
         auto& task = sends_.front();
-        ssize_t r = ::send(fd_, task.buffer, task.size, 0);
+        int64_t r = ::send(fd_, task.buffer, task.size, 0);
         if (r > 0) {
           task.buffer = static_cast<char*>(task.buffer) + r;
           task.size -= r;
@@ -283,12 +283,12 @@ class CommunicationThreads {
   }
 
   template <typename T>
-  std::future<void> send(int socket, T* buffer, size_t size) {
+  std::future<void> send(int socket, T* buffer, int64_t size) {
     return threads_.at(socket).send<T>(buffer, size);
   }
 
   template <typename T>
-  std::future<void> recv(int socket, T* buffer, size_t size) {
+  std::future<void> recv(int socket, T* buffer, int64_t size) {
     return threads_.at(socket).recv<T>(buffer, size);
   }
 
@@ -505,7 +505,7 @@ std::vector<int> make_connections(
 }
 template <typename T>
 struct SumOp {
-  void operator()(const T* input, T* output, size_t N) {
+  void operator()(const T* input, T* output, int64_t N) {
     while (N-- > 0) {
       *output += *input;
       input++;
@@ -516,7 +516,7 @@ struct SumOp {
 
 template <typename T>
 struct MaxOp {
-  void operator()(const T* input, T* output, size_t N) {
+  void operator()(const T* input, T* output, int64_t N) {
     while (N-- > 0) {
       *output = std::max(*output, *input);
       input++;
@@ -527,7 +527,7 @@ struct MaxOp {
 
 template <typename T>
 struct MinOp {
-  void operator()(const T* input, T* output, size_t N) {
+  void operator()(const T* input, T* output, int64_t N) {
     while (N-- > 0) {
       *output = std::min(*output, *input);
       input++;
@@ -542,7 +542,7 @@ class RingGroup : public GroupImpl {
  public:
   RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
       : rank_(rank), verbose_(verbose), pool_(0) {
-    if (rank_ > 0 && rank_ >= nodes.size()) {
+    if (rank_ > 0 && rank_ >= std::ssize(nodes)) {
       throw std::runtime_error(
           "[ring] Rank cannot be larger than the size of the group");
     }
@@ -589,7 +589,7 @@ class RingGroup : public GroupImpl {
 
     // Configure all sockets to use TCP no delay.
     int one = 1;
-    for (int i = 0; i < sockets_right_.size(); i++) {
+    for (int64_t i = 0; i < std::ssize(sockets_right_); i++) {
       setsockopt(sockets_right_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
       setsockopt(sockets_left_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
     }
@@ -646,7 +646,8 @@ class RingGroup : public GroupImpl {
         output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
   }
 
-  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
+  std::shared_ptr<GroupImpl> split(int /* color */, int /* key */ = -1)
+      override {
     throw std::runtime_error("[ring] Group split not supported.");
   }
 
@@ -658,15 +659,15 @@ class RingGroup : public GroupImpl {
                       nbytes = input.nbytes(),
                       output_ptr = output.data<char>(),
                       this]() {
-      constexpr size_t min_send_size = 262144;
-      size_t n_gathers = std::max(
-          std::min(
+      constexpr int64_t min_send_size = 262144;
+      int64_t n_gathers = std::max<int64_t>(
+          std::min<int64_t>(
               sockets_right_.size() + sockets_left_.size(),
               nbytes / min_send_size),
-          size_t(1));
-      size_t bytes_per_gather = ceildiv(nbytes, n_gathers);
+          1);
+      int64_t bytes_per_gather = ceildiv(nbytes, n_gathers);
       std::vector<std::future<void>> all_gathers;
-      for (int i = 0; i < n_gathers; i++) {
+      for (int64_t i = 0; i < n_gathers; i++) {
         auto offset = i * bytes_per_gather;
         all_gathers.emplace_back(pool_.enqueue(std::bind(
             &RingGroup::all_gather_impl,
@@ -742,10 +743,14 @@ class RingGroup : public GroupImpl {
     auto out_ptr = output.data<char>();
     auto& encoder = cpu::get_command_encoder(stream);
     encoder.set_output_array(output);
-    encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() {
+    encoder.dispatch([in_ptr,
+                      out_ptr,
+                      size = static_cast<int64_t>(input.size()),
+                      this,
+                      reduce_op]() {
       // If the input data cannot be split into size_ segments then copy it and
       // all reduce a local buffer prefilled with 0s.
-      size_t nbytes = size * sizeof(T);
+      int64_t nbytes = size * sizeof(T);
       if (size < size_) {
         // TODO: Maybe allocate dynamically so we don't have the constraint
         // below?
@@ -778,16 +783,16 @@ class RingGroup : public GroupImpl {
 
       // Split the all reduces so that each member has at least 1 buffer to
      // send/recv per segment.
-      constexpr size_t min_send_size = 262144;
-      size_t n_reduces = std::max(
-          std::min(
+      constexpr int64_t min_send_size = 262144;
+      int64_t n_reduces = std::max<int64_t>(
+          std::min<int64_t>(
               sockets_right_.size() + sockets_left_.size(),
               nbytes / (size_ * min_send_size)),
-          size_t(1));
-      size_t step = ceildiv(size, n_reduces);
+          1);
+      int64_t step = ceildiv(size, n_reduces);
       std::vector<std::future<void>> all_sums;
 
-      for (int i = 0; i < n_reduces; i++) {
+      for (int64_t i = 0; i < n_reduces; i++) {
         all_sums.emplace_back(pool_.enqueue(std::bind(
             &RingGroup::all_reduce_impl<T, ReduceOp>,
             this,
@@ -810,7 +815,7 @@ class RingGroup : public GroupImpl {
   void all_reduce_impl(
       T* buffer,
       T* data,
-      size_t data_size,
+      int64_t data_size,
       int socket_right,
       int socket_left,
       int direction,
@@ -821,10 +826,10 @@ class RingGroup : public GroupImpl {
 
    // We split the data into `size_` segments of size `segment_size` and each
    // of these in smaller segments of ALL_SUM_SIZE which we 'll call packets.
-    size_t segment_size = ceildiv(data_size, size_);
-    size_t BUFFER_SIZE = std::max(
-        size_t(32768), std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2));
-    size_t n_packets = ceildiv(segment_size, BUFFER_SIZE);
+    int64_t segment_size = ceildiv(data_size, size_);
+    int64_t BUFFER_SIZE = std::max<int64_t>(
+        32768, std::min<int64_t>(ALL_SUM_SIZE / sizeof(T), segment_size / 2));
+    int64_t n_packets = ceildiv(segment_size, BUFFER_SIZE);
 
     // Initial segments
     int send_segment = rank_;
@@ -833,21 +838,21 @@ class RingGroup : public GroupImpl {
    // Plan the whole reduce in terms of sends and recvs as indices in data.
    // It makes the actual async send and recv a bit simpler to follow when
    // there are less offset calculations around.
-    std::vector<std::pair<size_t, size_t>> send_plan;
-    std::vector<std::pair<size_t, size_t>> recv_plan;
+    std::vector<std::pair<int64_t, int64_t>> send_plan;
+    std::vector<std::pair<int64_t, int64_t>> recv_plan;
 
    // Two times the same send/recv operations, first scatter reduce and then
    // gather.
     for (int k = 0; k < 2; k++) {
       for (int i = 0; i < size_ - 1; i++) {
-        size_t send_start = send_segment * segment_size;
-        size_t send_stop =
+        int64_t send_start = send_segment * segment_size;
+        int64_t send_stop =
             std::min((send_segment + 1) * segment_size, data_size);
-        size_t recv_start = recv_segment * segment_size;
-        size_t recv_stop =
+        int64_t recv_start = recv_segment * segment_size;
+        int64_t recv_stop =
             std::min((recv_segment + 1) * segment_size, data_size);
 
-        for (size_t j = 0; j < n_packets; j++) {
+        for (int64_t j = 0; j < n_packets; j++) {
           send_plan.emplace_back(
               std::min(send_start + j * BUFFER_SIZE, send_stop),
               std::min(send_start + (j + 1) * BUFFER_SIZE, send_stop));
@@ -864,18 +869,18 @@ class RingGroup : public GroupImpl {
    // Running the plan is fairly simple, we keep a send and a recv in flight
    // while doing the summation.
     T* recv_buffers[ALL_SUM_BUFFERS];
-    for (int i = 0; i < ALL_SUM_BUFFERS; i++) {
+    for (int64_t i = 0; i < ALL_SUM_BUFFERS; i++) {
       recv_buffers[i] = buffer + i * BUFFER_SIZE;
     }
     std::future<void> sends[2], recvs[2];
     int a = 0;
     int b = (n_packets > 1) ? 1 : 0;
-    for (int i = 0, j = -b; i < send_plan.size(); j++, i++) {
+    for (int i = 0, j = -b; i < std::ssize(send_plan); j++, i++) {
       sends[a] = comm_.send(
           socket_send,
           data + send_plan[i].first,
           send_plan[i].second - send_plan[i].first);
-      if (2 * i < send_plan.size()) {
+      if (2 * i < std::ssize(send_plan)) {
         recvs[a] = comm_.recv(
             socket_recv,
             recv_buffers[i % ALL_SUM_BUFFERS],
@@ -890,7 +895,7 @@ class RingGroup : public GroupImpl {
       if (j >= 0) {
         sends[b].wait();
         recvs[b].wait();
-        if (2 * j < send_plan.size()) {
+        if (2 * j < std::ssize(send_plan)) {
           reduce_op(
               recv_buffers[j % ALL_SUM_BUFFERS],
               data + recv_plan[j].first,
@@ -907,8 +912,8 @@ class RingGroup : public GroupImpl {
   void all_gather_impl(
       const char* input,
       char* output,
-      size_t input_size,
-      size_t data_size,
+      int64_t input_size,
+      int64_t data_size,
       int socket_right,
       int socket_left,
       int direction) {
@@ -941,11 +946,11 @@ class RingGroup : public GroupImpl {
   }
 
   void
-  send(const std::vector<int>& sockets, const char* data, size_t data_size) {
-    size_t segment_size =
-        std::max(size_t(1024), ceildiv(data_size, sockets.size()));
+  send(const std::vector<int>& sockets, const char* data, int64_t data_size) {
+    int64_t segment_size =
+        std::max<int64_t>(1024, ceildiv(data_size, std::ssize(sockets)));
     std::vector<std::future<void>> sends;
-    for (int i = 0; i < sockets.size(); i++) {
+    for (int i = 0; i < std::ssize(sockets); i++) {
       if (i * segment_size >= data_size) {
         break;
       }
@@ -959,11 +964,11 @@ class RingGroup : public GroupImpl {
     }
   }
 
-  void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
-    size_t segment_size =
-        std::max(size_t(1024), ceildiv(data_size, sockets.size()));
+  void recv(const std::vector<int>& sockets, char* data, int64_t data_size) {
+    int64_t segment_size =
+        std::max<int64_t>(1024, ceildiv(data_size, std::ssize(sockets)));
     std::vector<std::future<void>> recvs;
-    for (int i = 0; i < sockets.size(); i++) {
+    for (int i = 0; i < std::ssize(sockets); i++) {
      if (i * segment_size >= data_size) {
        break;
      }
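
Aside on the pattern this commit applies throughout: replacing size_t with int64_t and container.size() with std::ssize(container) (C++20) sidesteps the usual signed/unsigned traps. A minimal self-contained sketch of those traps, using toy values and a stand-in send_plan vector rather than the commit's real state:

// Standalone sketch, not part of the commit: the two traps that the
// size_t -> int64_t / std::ssize migration above avoids. Compile as C++20.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

int main() {
  std::vector<int> send_plan(4);  // stand-in for the commit's plan vector

  // Trap 1: signed/unsigned comparison. With `j < send_plan.size()` a
  // negative j converts to a huge unsigned value and the loop never runs;
  // std::ssize returns a signed count, so the comparison behaves:
  for (std::int64_t j = -1; j < std::ssize(send_plan); j++) {
    std::cout << j << ' ';  // prints: -1 0 1 2 3
  }
  std::cout << '\n';

  // Trap 2: std::min/std::max cannot deduce a common type from one size_t
  // and one int64_t argument; spelling out the template argument, as the
  // commit does, makes the mixed-type call compile:
  std::size_t n_sockets = send_plan.size();
  std::int64_t nbytes = 1 << 20;
  std::int64_t n = std::max<std::int64_t>(
      std::min<std::int64_t>(n_sockets, nbytes / 262144), 1);
  std::cout << n << '\n';  // prints: 4
}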
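
Aside on the planning loop that several hunks above touch: all_reduce_impl splits the data into size_ segments, splits each segment into packets of at most BUFFER_SIZE, and emits the same send schedule twice, once for the scatter-reduce phase and once for the gather phase. The sketch below reconstructs only that index arithmetic with made-up ring parameters; the segment-rotation rule and everything network-related are assumptions, not code from this commit:

// Toy reconstruction of the packet planning in RingGroup::all_reduce_impl.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

std::int64_t ceildiv(std::int64_t a, std::int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int size_ = 4;                   // ring size (hypothetical)
  const int rank_ = 1;                   // this node's rank (hypothetical)
  const int direction = -1;              // assumed: -1 sends left
  const std::int64_t data_size = 10;     // deliberately not divisible by size_
  const std::int64_t BUFFER_SIZE = 2;    // toy packet size, in elements

  std::int64_t segment_size = ceildiv(data_size, size_);        // 3
  std::int64_t n_packets = ceildiv(segment_size, BUFFER_SIZE);  // 2

  int send_segment = rank_;
  std::vector<std::pair<std::int64_t, std::int64_t>> send_plan;
  // Same schedule twice: once for scatter-reduce, once for the gather.
  for (int k = 0; k < 2; k++) {
    for (int i = 0; i < size_ - 1; i++) {
      std::int64_t start = send_segment * segment_size;
      std::int64_t stop =
          std::min((send_segment + 1) * segment_size, data_size);
      // Each segment becomes n_packets clamped [begin, end) packets; the
      // last segment can yield zero-length packets, which the socket layer's
      // `if (size == 0)` short-circuit makes harmless.
      for (std::int64_t j = 0; j < n_packets; j++) {
        send_plan.emplace_back(
            std::min(start + j * BUFFER_SIZE, stop),
            std::min(start + (j + 1) * BUFFER_SIZE, stop));
      }
      send_segment = (send_segment - direction + size_) % size_;  // assumed
    }
  }
  for (auto [s, e] : send_plan) {
    std::cout << '[' << s << ", " << e << ") ";
  }
  std::cout << '\n';
}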