mirror of https://github.com/ml-explore/mlx.git
redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch
* load + more async CPU
* use command encoder API and move more ops to use it
* make fence back-end generic + CPU only fence
* faster build
* fix async eval
* fixes + handle temporaries
* fix / improve cpu conv
* remove unused status, fix siblings
* fix extensions
* fix
* fix no cpu build
* format
* comments
* fix perf regression, remove unnecessary abort
* fix events, task limit cpu
* fix waiting
* fix donation / temporaries in normalization
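The whole diff applies one pattern: the blocking `detail::all_sum(group, input, output)` style entry points grow a `Stream` parameter, and each backend enqueues its communication on the CPU command encoder for that stream instead of running it on the caller's thread. A minimal sketch of the new shape, using the encoder names that appear in the diff below (illustrative only, not a drop-in implementation):

```cpp
#include "mlx/backend/cpu/encoder.h" // cpu::get_command_encoder

// Old shape (removed below): ran synchronously on the calling thread.
//   void all_sum(const array& input, array& output);
//
// New shape: declare dependencies, then defer the work to the stream.
void all_sum(const array& input, array& output, Stream stream) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(input);   // reads ordered after the input's producer
  encoder.set_output_array(output); // writes visible to later consumers
  encoder.dispatch([in = input.data<void>(),
                    out = output.data<void>(),
                    n = input.size()]() {
    // Backend-specific communication (MPI, ring sockets, ...) runs here,
    // asynchronously with respect to the thread that built the graph.
  });
}
```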
@@ -6,31 +6,25 @@
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/ring/ring.h"
#include "mlx/scheduler.h"

namespace mlx::core::distributed {

namespace detail {

Stream communication_stream() {
static Stream comm_stream = new_stream(Device::cpu);
return comm_stream;
void all_sum(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_sum(input, output, stream);
}

void all_sum(Group group, const array& input, array& output) {
group.raw_group()->all_sum(input, output);
void all_gather(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_gather(input, output, stream);
}

void all_gather(Group group, const array& input, array& output) {
group.raw_group()->all_gather(input, output);
void send(Group group, const array& input, int dst, Stream stream) {
group.raw_group()->send(input, dst, stream);
}

void send(Group group, const array& input, int dst) {
group.raw_group()->send(input, dst);
}

void recv(Group group, array& out, int src) {
group.raw_group()->recv(out, src);
void recv(Group group, array& out, int src, Stream stream) {
group.raw_group()->recv(out, src, stream);
}

class EmptyGroup : public GroupImpl {
@@ -47,19 +41,19 @@ class EmptyGroup : public GroupImpl {
throw std::runtime_error("Cannot split the distributed group further.");
}

void all_sum(const array& input, array& output) override {
void all_sum(const array&, array&, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void all_gather(const array& input, array& output) override {
void all_gather(const array&, array&, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void send(const array& input, int dst) override {
void send(const array&, int, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void recv(array& out, int src) override {
void recv(array&, int, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
@@ -122,10 +116,6 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
backends.insert({"any", group});
}
backends.insert({std::move(bk_), group});

// Ensure the communication stream is alive before
// the graph is evaluated
detail::communication_stream();
return Group(group);
}
@@ -17,25 +17,22 @@ class GroupImpl {
virtual int size() = 0;
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;

virtual void all_sum(const array& input, array& output) = 0;
virtual void all_gather(const array& input, array& output) = 0;
virtual void send(const array& input, int dst) = 0;
virtual void recv(array& out, int src) = 0;
virtual void all_sum(const array& input, array& output, Stream stream) = 0;
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
virtual void send(const array& input, int dst, Stream stream) = 0;
virtual void recv(array& out, int src, Stream stream) = 0;
};

/* Return the communication stream. */
Stream communication_stream();

/* Perform an all reduce sum operation */
void all_sum(Group group, const array& input, array& output);
void all_sum(Group group, const array& input, array& output, Stream stream);

/* Perform an all gather operation */
void all_gather(Group group, const array& input, array& output);
void all_gather(Group group, const array& input, array& output, Stream stream);

/** Send an array to the dst rank */
void send(Group group, const array& input, int dst);
void send(Group group, const array& input, int dst, Stream stream);

/** Recv an array from the src rank */
void recv(Group group, array& out, int src);
void recv(Group group, array& out, int src, Stream stream);

} // namespace mlx::core::distributed::detail
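The header above is the whole contract a backend has to meet after this change: the four collectives take the stream they should run on. As a concrete (hypothetical) illustration, a single-process group could satisfy it by turning every collective into a copy on the stream's CPU encoder; the class below is not part of the diff, only a sketch of the interface in use.

```cpp
#include <cstring>
#include <stdexcept>

#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed_impl.h"

namespace mlx::core::distributed::detail {

// Hypothetical backend with exactly one member; rank()/size() are assumed to
// be virtuals declared in the part of GroupImpl not shown in the hunk above.
class SingleProcessGroup : public GroupImpl {
 public:
  int rank() override { return 0; }
  int size() override { return 1; }
  std::shared_ptr<GroupImpl> split(int, int = -1) override {
    throw std::runtime_error("Nothing to split in a single-process group.");
  }
  void all_sum(const array& input, array& output, Stream stream) override {
    auto& enc = cpu::get_command_encoder(stream);
    enc.set_input_array(input);
    enc.set_output_array(output);
    enc.dispatch([in = input.data<char>(), out = output.data<char>(),
                  n = input.nbytes()]() {
      std::memmove(out, in, n); // memmove in case the input was donated
    });
  }
  void all_gather(const array& input, array& output, Stream stream) override {
    all_sum(input, output, stream); // with one member, gather is also a copy
  }
  void send(const array&, int, Stream) override {
    throw std::runtime_error("No peers to send to.");
  }
  void recv(array&, int, Stream) override {
    throw std::runtime_error("No peers to receive from.");
  }
};

} // namespace mlx::core::distributed::detail
```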
@@ -3,11 +3,10 @@
#include <dlfcn.h>
#include <mpi.h>

#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/scheduler.h"

#define LOAD_SYMBOL(symbol, variable) \
{ \
@@ -25,16 +24,6 @@ using GroupImpl = mlx::core::distributed::detail::GroupImpl;

namespace {

array ensure_row_contiguous(const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
}

template <typename T>
void simple_sum(
void* input,
@@ -281,9 +270,12 @@ class MPIGroup : public GroupImpl {
return std::make_shared<MPIGroup>(new_comm, false);
}

void all_sum(const array& input_, array& output) override {
array input = ensure_row_contiguous(input_);
mpi().all_reduce(
void all_sum(const array& input, array& output, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.set_output_array(output);
encoder.dispatch(
mpi().all_reduce,
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
: input.data<void>(),
output.data<void>(),
@@ -293,9 +285,12 @@ class MPIGroup : public GroupImpl {
comm_);
}

void all_gather(const array& input_, array& output) override {
array input = ensure_row_contiguous(input_);
mpi().all_gather(
void all_gather(const array& input, array& output, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.set_output_array(output);
encoder.dispatch(
mpi().all_gather,
input.data<void>(),
input.size(),
mpi().datatype(input),
@@ -305,22 +300,30 @@ class MPIGroup : public GroupImpl {
comm_);
}

void send(const array& input_, int dst) override {
array input = ensure_row_contiguous(input_);
mpi().send(
input.data<void>(), input.size(), mpi().datatype(input), dst, 0, comm_);
void send(const array& input, int dst, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.dispatch(
mpi().send,
input.data<void>(),
input.size(),
mpi().datatype(input),
dst,
0,
comm_);
}

void recv(array& out, int src) override {
MPI_Status status;
mpi().recv(
out.data<void>(),
out.size(),
mpi().datatype(out),
src,
MPI_ANY_TAG,
comm_,
&status);
void recv(array& out, int src, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<void>(),
out_size = out.size(),
out_type = mpi().datatype(out),
src,
comm = comm_]() {
MPI_Status status;
mpi().recv(out_ptr, out_size, out_type, src, MPI_ANY_TAG, comm, &status);
});
}

private:
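A detail worth noting in the MPI recv path above: everything the deferred task needs (the output pointer, element count, MPI datatype, source rank, and communicator) is captured by value when the lambda is built, so the task never touches the `array` object once it runs asynchronously. The same idea in a self-contained form, with a toy task queue standing in for the CPU command encoder (the queue is an assumption for illustration, not MLX's encoder):

```cpp
#include <functional>
#include <iostream>
#include <queue>
#include <vector>

int main() {
  std::queue<std::function<void()>> tasks; // stand-in for the encoder's queue

  std::vector<float> out(4, 0.0f);
  {
    // Capture the raw pointer and size by value now; the vector object itself
    // is not referenced when the task eventually runs.
    float* out_ptr = out.data();
    size_t out_size = out.size();
    tasks.push([out_ptr, out_size]() {
      for (size_t i = 0; i < out_size; ++i) {
        out_ptr[i] = static_cast<float>(i); // pretend this is the MPI receive
      }
    });
  }

  // ... graph construction continues; later the "encoder" drains its queue.
  while (!tasks.empty()) {
    tasks.front()();
    tasks.pop();
  }
  std::cout << out[3] << std::endl; // prints 3
}
```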
@@ -28,11 +28,11 @@ array all_sum(
if (group.size() == 1) {
return x;
}

return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Sum),
{x});
}

@@ -55,7 +55,7 @@ array all_gather(
return array(
std::move(result_shape),
x.dtype(),
std::make_shared<AllGather>(to_stream(s), group),
std::make_shared<AllGather>(to_stream(s, Device::cpu), group),
{x});
}

@@ -80,7 +80,7 @@ array send(
return array(
x.shape(),
x.dtype(),
std::make_shared<Send>(to_stream(s), group, dst),
std::make_shared<Send>(to_stream(s, Device::cpu), group, dst),
{x});
}

@@ -105,7 +105,7 @@ array recv(
return array(
std::move(shape),
std::move(dtype),
std::make_shared<Recv>(to_stream(s), group, src),
std::make_shared<Recv>(to_stream(s, Device::cpu), group, src),
std::vector<array>{});
}
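The only change in the four constructors above is `to_stream(s)` becoming `to_stream(s, Device::cpu)`: whatever stream or device the caller passes, the distributed primitives are now pinned to a CPU stream, because the collectives always run through the CPU encoder. A hedged sketch of what that means at the call site (the surrounding code is illustrative, not taken from the diff):

```cpp
// Even with the GPU as the default device, the AllReduce primitive is built
// on a CPU stream; passing an empty StreamOrDevice resolves to the default
// CPU stream rather than the default device's stream.
StreamOrDevice s = {};
auto out = array(
    x.shape(),
    x.dtype(),
    std::make_shared<AllReduce>(
        to_stream(s, Device::cpu), group, AllReduce::Sum),
    {x});
```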
@@ -3,34 +3,12 @@
#include <cassert>

#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/ops.h"

namespace mlx::core::distributed {

void AllReduce::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);

if (inputs[0].is_donatable()) {
outputs[0].copy_shared_buffer(inputs[0]);
} else {
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
}

switch (reduce_type_) {
case Sum:
distributed::detail::all_sum(group(), inputs[0], outputs[0]);
break;
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
}

std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -62,17 +40,6 @@ std::vector<array> AllReduce::vjp(
return cotangents;
}

void AllGather::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);

outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));

distributed::detail::all_gather(group(), inputs[0], outputs[0]);
}

std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -99,30 +66,10 @@ std::vector<array> AllGather::vjp(
return {slice(cotangents[0], starts, stops)};
}

void Send::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);

distributed::detail::send(group(), inputs[0], dst_);
move_or_copy(inputs[0], outputs[0]);
}

std::pair<std::vector<array>, std::vector<int>> Send::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
return {{send(inputs[0], dst_, group(), stream())}, axes};
}

void Recv::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 0);
assert(outputs.size() == 1);

outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
distributed::detail::recv(group(), outputs[0], src_);
}

} // namespace mlx::core::distributed
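One detail the removed eval_cpu bodies (and their replacements elsewhere in this PR) share is the donate-or-allocate step: when the input buffer can be donated, the output aliases it and the reduction runs in place, which is also why the MPI path earlier compares the two data pointers and passes MPI_IN_PLACE; otherwise a fresh buffer is allocated. The same decision in a self-contained form, with shared_ptr-backed vectors standing in for arrays (helper names here are hypothetical):

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-in: a buffer is "donatable" when no one else holds it.
using Buffer = std::shared_ptr<std::vector<float>>;
bool is_donatable(const Buffer& b) { return b.use_count() == 1; }

Buffer prepare_output(const Buffer& input, size_t n) {
  if (is_donatable(input)) {
    return input; // alias the input and reduce in place
  }
  return std::make_shared<std::vector<float>>(n); // otherwise allocate
}

int main() {
  Buffer donatable = std::make_shared<std::vector<float>>(8, 1.0f);
  assert(prepare_output(donatable, 8) == donatable); // aliased, no copy

  Buffer shared = std::make_shared<std::vector<float>>(8, 1.0f);
  Buffer still_referenced = shared; // someone else still reads this buffer
  assert(prepare_output(shared, 8) != shared); // fresh allocation instead
  return 0;
}
```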
@@ -18,7 +18,7 @@

#include <json.hpp>

#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/threadpool.h"
@@ -140,8 +140,8 @@ class SocketThread {
}

template <typename T>
std::future<void> send(T* buffer, size_t size) {
return send_impl(reinterpret_cast<char*>(buffer), size * sizeof(T));
std::future<void> send(const T* buffer, size_t size) {
return send_impl(reinterpret_cast<const char*>(buffer), size * sizeof(T));
}

template <typename T>
@@ -160,7 +160,7 @@ class SocketThread {
std::promise<void> promise;
};

std::future<void> send_impl(char* buffer, size_t size) {
std::future<void> send_impl(const char* buffer, size_t size) {
std::promise<void> send_completed_promise;
auto send_completed_future = send_completed_promise.get_future();
if (size == 0) {
@@ -170,8 +170,8 @@ class SocketThread {

{
std::unique_lock lock(queue_mutex_);
sends_.emplace_back(
SocketTask(buffer, size, std::move(send_completed_promise)));
sends_.emplace_back(SocketTask(
const_cast<char*>(buffer), size, std::move(send_completed_promise)));
}
condition_.notify_one();
return send_completed_future;
@@ -503,16 +503,6 @@ std::vector<int> make_connections(
return sockets;
}

array ensure_row_contiguous(const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
}

template <typename T>
void sum_inplace(const T* input, T* output, size_t N) {
while (N-- > 0) {
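The SocketThread changes above are purely about constness: the ring backend now hands `send` pointers obtained from `const` inputs, so `send`/`send_impl` take `const T*` / `const char*`, and the qualifier is dropped with a single `const_cast` only where the pointer is stored into the pre-existing, non-const `SocketTask` (which is shared with the receive path that really does write through it). A toy version of that boundary (names are hypothetical):

```cpp
#include <cstddef>
#include <vector>

// Toy stand-in for SocketTask: it stores a mutable pointer because the same
// struct is reused for receives, which genuinely write through it.
struct Task {
  char* buffer;
  size_t size;
};

std::vector<Task> queue;

// Const-correct producer API: callers holding read-only data can use it.
void enqueue_send(const char* data, size_t size) {
  // The cast is confined to this one line; the send path never writes.
  queue.push_back(Task{const_cast<char*>(data), size});
}

int main() {
  const char payload[] = "hello";
  enqueue_send(payload, sizeof(payload));
  return queue.size() == 1 ? 0 : 1;
}
```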
@@ -613,117 +603,131 @@ class RingGroup : public GroupImpl {
return size_;
}

void all_sum(const array& input_, array& output) override {
SWITCH_TYPE(output, all_sum<T>(input_, output));
void all_sum(const array& input_, array& output, Stream stream) override {
SWITCH_TYPE(output, all_sum<T>(input_, output, stream));
}

std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
throw std::runtime_error("[ring] Group split not supported.");
}

void all_gather(const array& input, array& output) override {
void all_gather(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[ring] All gather not supported.");
}

void send(const array& input_, int dst) override {
// Make sure that the input is row contiguous
array input = ensure_row_contiguous(input_);

int right = (rank_ + 1) % size_;
int left = (rank_ + size_ - 1) % size_;
if (dst == right) {
send(sockets_right_, input.data<char>(), input.nbytes());
} else if (dst == left) {
send(sockets_left_, input.data<char>(), input.nbytes());
} else {
std::ostringstream msg;
msg << "[ring] Send only supported to direct neighbors "
<< "but tried to send to " << dst << " from " << rank_ << std::endl;
throw std::runtime_error(msg.str());
}
void send(const array& input, int dst, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.dispatch(
[input_ptr = input.data<char>(), nbytes = input.nbytes(), dst, this]() {
int right = (rank_ + 1) % size_;
int left = (rank_ + size_ - 1) % size_;
if (dst == right) {
send(sockets_right_, input_ptr, nbytes);
} else if (dst == left) {
send(sockets_left_, input_ptr, nbytes);
} else {
std::ostringstream msg;
msg << "[ring] Send only supported to direct neighbors "
<< "but tried to send to " << dst << " from " << rank_
<< std::endl;
throw std::runtime_error(msg.str());
}
});
}

void recv(array& out, int src) override {
// NOTE: We 'll check the sockets with the opposite order of send so that
// they work even with 2 nodes where left and right is the same
// neighbor.
int right = (rank_ + 1) % size_;
int left = (rank_ + size_ - 1) % size_;
if (src == left) {
recv(sockets_left_, out.data<char>(), out.nbytes());
} else if (src == right) {
recv(sockets_right_, out.data<char>(), out.nbytes());
} else {
std::ostringstream msg;
msg << "[ring] Recv only supported from direct neighbors "
<< "but tried to recv from " << src << " to " << rank_ << std::endl;
throw std::runtime_error(msg.str());
}
void recv(array& out, int src, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
encoder.dispatch(
[out_ptr = out.data<char>(), nbytes = out.nbytes(), src, this]() {
// NOTE: We 'll check the sockets with the opposite order of send so
// that
// they work even with 2 nodes where left and right is the same
// neighbor.
int right = (rank_ + 1) % size_;
int left = (rank_ + size_ - 1) % size_;
if (src == left) {
recv(sockets_left_, out_ptr, nbytes);
} else if (src == right) {
recv(sockets_right_, out_ptr, nbytes);
} else {
std::ostringstream msg;
msg << "[ring] Recv only supported from direct neighbors "
<< "but tried to recv from " << src << " to " << rank_
<< std::endl;
throw std::runtime_error(msg.str());
}
});
}
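Both the removed and the dispatched versions above route a message with the same ring arithmetic, and the recv side deliberately checks the left neighbor before the right one (the opposite order of send) so that a two-node ring, where the left and right neighbor are the same rank, still pairs each send with the matching receive. The arithmetic in isolation:

```cpp
#include <iostream>

int main() {
  int size = 2; // two-node ring: left and right are the same peer
  for (int rank = 0; rank < size; ++rank) {
    int right = (rank + 1) % size;
    int left = (rank + size - 1) % size;
    std::cout << "rank " << rank << ": right=" << right
              << " left=" << left << std::endl; // both print the other rank
  }
}
```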
private:
template <typename T>
void all_sum(const array& input_, array& output) {
// Make sure that the input is row contiguous
array input = ensure_row_contiguous(input_);
void all_sum(const array& input, array& output, Stream stream) {
auto in_ptr = input.data<char>();
auto out_ptr = output.data<char>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(output);
encoder.dispatch([in_ptr, out_ptr, size = input.size(), this]() {
// If the input data cannot be split into size_ segments then copy it and
// all reduce a local buffer prefilled with 0s.
size_t nbytes = size * sizeof(T);
if (size < size_) {
// TODO: Maybe allocate dynamically so we don't have the constraint
// below?
if (sizeof(T) * size_ > 1024) {
std::ostringstream msg;
msg << "Can't perform the ring all reduce of " << size
<< " elements with a ring of size " << size_;
throw std::runtime_error(msg.str());
}

// If the input data cannot be split into size_ segments then copy it and
// all reduce a local buffer prefilled with 0s.
if (input.size() < size_) {
// TODO: Maybe allocate dynamically so we don't have the constraint
// below?
if (input.itemsize() * size_ > 1024) {
std::ostringstream msg;
msg << "Can't perform the ring all reduce of " << output.size()
<< " elements with a ring of size " << size_;
throw std::runtime_error(msg.str());
char buffer[1024];
std::memset(buffer, 0, size_ * sizeof(T));
std::memcpy(buffer, in_ptr, nbytes);
all_sum_impl<T>(
reinterpret_cast<T*>(buffers_.data()),
reinterpret_cast<T*>(buffer),
size_,
sockets_right_[0],
sockets_left_[0],
-1);
std::memcpy(out_ptr, buffer, nbytes);
return;
}

char buffer[1024];
std::memset(buffer, 0, size_ * input.itemsize());
std::memcpy(buffer, input.data<char>(), input.nbytes());
all_sum_impl<T>(
reinterpret_cast<T*>(buffers_.data()),
reinterpret_cast<T*>(buffer),
size_,
sockets_right_[0],
sockets_left_[0],
-1);
std::memcpy(output.data<char>(), buffer, output.nbytes());
return;
}
// If not inplace all reduce then copy the input to the output first
if (in_ptr != out_ptr) {
std::memcpy(out_ptr, in_ptr, nbytes);
}

// If not inplace all reduce then copy the input to the output first
if (input.data<void>() != output.data<void>()) {
std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
}
// Split the all reduces so that each member has at least 1 buffer to
// send/recv per segment.
constexpr size_t min_send_size = 262144;
size_t n_reduces = std::max(
std::min(
sockets_right_.size() + sockets_left_.size(),
nbytes / (size_ * min_send_size)),
1UL);
size_t step = ceildiv(size, n_reduces);
std::vector<std::future<void>> all_sums;

// Split the all reduces so that each member has at least 1 buffer to
// send/recv per segment.
constexpr size_t min_send_size = 262144;
size_t n_reduces = std::max(
std::min(
sockets_right_.size() + sockets_left_.size(),
output.nbytes() / (size_ * min_send_size)),
1UL);
size_t step = ceildiv(output.size(), n_reduces);
std::vector<std::future<void>> all_sums;

for (int i = 0; i < n_reduces; i++) {
all_sums.emplace_back(pool_.enqueue(std::bind(
&RingGroup::all_sum_impl<T>,
this,
reinterpret_cast<T*>(
buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS),
output.data<T>() + i * step,
std::min(output.size(), (i + 1) * step) - i * step,
sockets_right_[i / 2],
sockets_left_[i / 2],
(i % 2) ? -1 : 1)));
}
for (auto& f : all_sums) {
f.wait();
}
for (int i = 0; i < n_reduces; i++) {
all_sums.emplace_back(pool_.enqueue(std::bind(
&RingGroup::all_sum_impl<T>,
this,
reinterpret_cast<T*>(
buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS),
reinterpret_cast<T*>(out_ptr) + i * step,
std::min(size, (i + 1) * step) - i * step,
sockets_right_[i / 2],
sockets_left_[i / 2],
(i % 2) ? -1 : 1)));
}
for (auto& f : all_sums) {
f.wait();
}
});
}
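Apart from reading its sizes from the values captured at dispatch time, the segmentation logic above is unchanged: the reduction is split into `n_reduces` independent ring all-reduces, capped by the number of connected sockets and by the requirement that each rank still moves at least `min_send_size` (256 KiB) per segment, and each sub-reduce handles `step = ceildiv(elements, n_reduces)` elements. A worked example of that arithmetic under assumed sizes:

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

size_t ceildiv(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
  // Assumed setup: 4 ranks, 2 sockets per direction, 8M float32 elements (32 MiB).
  size_t ring_size = 4;
  size_t n_sockets = 2 + 2;              // sockets_right_.size() + sockets_left_.size()
  size_t nbytes = 8u * 1024 * 1024 * 4;  // bytes in the local shard
  constexpr size_t min_send_size = 262144;

  size_t n_reduces = std::max(
      std::min(n_sockets, nbytes / (ring_size * min_send_size)), size_t(1));
  size_t step = ceildiv(nbytes / 4, n_reduces); // elements handled per sub-reduce

  std::cout << n_reduces << " parallel reductions, " << step
            << " elements each" << std::endl; // 4 reductions of 2097152 elements
}
```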
template <typename T>
@@ -823,7 +827,8 @@ class RingGroup : public GroupImpl {
recvs[b].wait();
}

void send(const std::vector<int>& sockets, char* data, size_t data_size) {
void
send(const std::vector<int>& sockets, const char* data, size_t data_size) {
size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
std::vector<std::future<void>> sends;
for (int i = 0; i < sockets.size(); i++) {