Add async all reduce and donation

Author: Angelos Katharopoulos
Date:   2024-05-29 22:54:32 -07:00
Parent: 9f9cb7a2ef
Commit: 3a1df968cf
8 changed files with 109 additions and 23 deletions

View File

@@ -252,8 +252,9 @@ class array {
   }
 
   /** True indicates the arrays buffer is safe to reuse */
-  bool is_donatable() const {
-    return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
+  bool is_donatable(int known_instances = 1) const {
+    return array_desc_.use_count() == known_instances &&
+        (array_desc_->data.use_count() == 1);
   }
 
   /** The array's siblings. */
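The new `known_instances` parameter generalizes the donation check: a caller that knows the input is also referenced by N sibling outputs can still detect a reusable buffer. A minimal sketch of the intended usage, with a hypothetical helper name (`write_output` is not part of this commit), built only from calls that appear later in the diff:

// Hypothetical helper: donate the input's buffer when the caller accounts for
// all known references to it, otherwise allocate fresh memory for the output.
void write_output(const array& in, array& out, int known_instances) {
  if (in.is_donatable(known_instances)) {
    out.copy_shared_buffer(in); // reuse the input's buffer in place
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }
}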

View File

@@ -53,6 +53,10 @@ Stream communication_stream();
 
 /* Perform an all reduce sum operation */
 void all_reduce_sum(Group group, const array& input, array& output);
+void all_reduce_sum(
+    Group group,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs);
 
 /* Perform an all reduce sum operation */
 void all_gather(Group group, const array& input, array& output);
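The batched overload lets a backend reduce many arrays in one call. A backend without non-blocking collectives could, as a fallback sketch (not what the MPI backend below does), simply loop over the single-array overload:

// Fallback sketch: satisfy the batched entry point with repeated blocking calls.
void all_reduce_sum(
    Group group,
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  for (size_t i = 0; i < inputs.size(); i++) {
    all_reduce_sum(group, inputs[i], outputs[i]);
  }
}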

View File

@@ -46,7 +46,9 @@ struct MPIWrapper {
     LOAD_SYMBOL(MPI_Comm_split, comm_split);
     LOAD_SYMBOL(MPI_Comm_free, comm_free);
     LOAD_SYMBOL(MPI_Allreduce, all_reduce);
+    LOAD_SYMBOL(MPI_Iallreduce, async_all_reduce);
     LOAD_SYMBOL(MPI_Allgather, all_gather);
+    LOAD_SYMBOL(MPI_Waitall, waitall);
 
     // Objects
     LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@@ -130,6 +132,14 @@ struct MPIWrapper {
   int (*finalize)();
   int (*rank)(MPI_Comm, int*);
   int (*size)(MPI_Comm, int*);
+  int (*async_all_reduce)(
+      const void*,
+      void*,
+      int,
+      MPI_Datatype,
+      MPI_Op,
+      MPI_Comm,
+      MPI_Request*);
   int (*all_reduce)(const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm);
   int (*all_gather)(
       const void*,
@@ -139,6 +149,7 @@ struct MPIWrapper {
       int,
       MPI_Datatype,
       MPI_Comm);
+  int (*waitall)(int, MPI_Request*, MPI_Status*);
   int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
   int (*comm_free)(MPI_Comm*);
@@ -258,7 +269,8 @@ Stream communication_stream() {
 void all_reduce_sum(Group group, const array& input_, array& output) {
   array input = ensure_row_contiguous(input_);
   mpi().all_reduce(
-      input.data<void>(),
+      (input.data<void>() != output.data<void>()) ? input.data<void>()
+                                                  : MPI_IN_PLACE,
       output.data<void>(),
       input.size(),
       mpi().datatype(input),
@@ -266,6 +278,27 @@ void all_reduce_sum(Group group, const array& input_, array& output) {
       to_comm(group));
 }
 
+void all_reduce_sum(
+    Group group,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  std::vector<MPI_Request> requests(inputs.size());
+  std::vector<MPI_Status> statuses(inputs.size());
+  for (int i = 0; i < inputs.size(); i++) {
+    array input = ensure_row_contiguous(inputs[i]);
+    mpi().async_all_reduce(
+        (input.data<void>() != outputs[i].data<void>()) ? input.data<void>()
+                                                        : MPI_IN_PLACE,
+        outputs[i].data<void>(),
+        input.size(),
+        mpi().datatype(input),
+        mpi().op_sum(),
+        to_comm(group),
+        &requests[i]);
+  }
+  mpi().waitall(requests.size(), &requests[0], &statuses[0]);
+}
+
 void all_gather(Group group, const array& input_, array& output) {
   array input = ensure_row_contiguous(input_);
   mpi().all_gather(

View File

@@ -32,6 +32,10 @@ Stream communication_stream() {
 }
 
 void all_reduce_sum(Group group, const array& input, array& output) {}
+void all_reduce_sum(
+    Group group,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {}
 void all_gather(Group group, const array& input, array& output) {}
 
 } // namespace detail

View File

@@ -18,17 +18,33 @@ Group to_group(std::optional<Group> group) {
 } // namespace
 
 array all_reduce_sum(const array& x, std::optional<Group> group_) {
+  return all_reduce_sum(std::vector<array>{x}, std::move(group_))[0];
+}
+
+std::vector<array> all_reduce_sum(
+    const std::vector<array>& xs,
+    std::optional<Group> group_) {
   auto group = to_group(group_);
 
   if (group.size() == 1) {
-    return x;
+    return xs;
   }
 
-  return array(
-      x.shape(),
-      x.dtype(),
+  std::vector<std::vector<int>> shapes;
+  std::vector<Dtype> dtypes;
+  shapes.reserve(xs.size());
+  dtypes.reserve(xs.size());
+  for (const auto& x : xs) {
+    shapes.push_back(x.shape());
+    dtypes.push_back(x.dtype());
+  }
+
+  return array::make_arrays(
+      std::move(shapes),
+      std::move(dtypes),
       std::make_shared<AllReduce>(group, AllReduce::Sum),
-      {x});
+      xs);
 }
 
 array all_gather(const array& x, std::optional<Group> group_) {
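From the caller's side, the vector overload turns many independent reductions into a single multi-output `AllReduce` primitive. A hedged usage sketch, assuming an already configured distributed backend, that `mlx/mlx.h` pulls in the distributed headers, and that averaging (dividing by the group size) is the caller's choice rather than part of this commit:

#include "mlx/mlx.h"

namespace mx = mlx::core;

// Sum a batch of gradient arrays across all processes, then average them.
std::vector<mx::array> average_grads(const std::vector<mx::array>& grads) {
  auto group = mx::distributed::init();
  auto summed = mx::distributed::all_reduce_sum(grads);
  for (auto& g : summed) {
    g = g / group.size(); // divide by the number of processes
  }
  return summed;
}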

View File

@@ -9,6 +9,9 @@
 namespace mlx::core::distributed {
 
 array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
+std::vector<array> all_reduce_sum(
+    const std::vector<array>& xs,
+    std::optional<Group> group = std::nullopt);
 array all_gather(const array& x, std::optional<Group> group = std::nullopt);
 
 } // namespace mlx::core::distributed

View File

@@ -13,17 +13,30 @@ namespace mlx::core::distributed {
 void AllReduce::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  assert(inputs.size() == 1);
-  assert(outputs.size() == 1);
-
-  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
-
-  switch (reduce_type_) {
-    case Sum:
-      distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
-      break;
-    default:
-      throw std::runtime_error("Only all reduce sum is supported for now");
+  for (int i = 0; i < inputs.size(); i++) {
+    if (inputs[i].is_donatable(outputs.size())) {
+      outputs[i].copy_shared_buffer(inputs[i]);
+    } else {
+      outputs[i].set_data(allocator::malloc_or_wait(outputs[i].nbytes()));
+    }
+  }
+
+  if (inputs.size() == 1) {
+    switch (reduce_type_) {
+      case Sum:
+        distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
+        break;
+      default:
+        throw std::runtime_error("Only all reduce sum is supported for now");
+    }
+  } else {
+    switch (reduce_type_) {
+      case Sum:
+        distributed::detail::all_reduce_sum(group(), inputs, outputs);
+        break;
+      default:
+        throw std::runtime_error("Only all reduce sum is supported for now");
+    }
   }
 }

View File

@@ -6,6 +6,7 @@
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/ops.h"
+#include "python/src/trees.h"
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -60,19 +61,30 @@ void init_distributed(nb::module_& parent_module) {
   m.def(
       "all_reduce_sum",
-      &distributed::all_reduce_sum,
-      "x"_a,
+      [](const nb::args& args, std::optional<distributed::Group> group) {
+        auto [xs, structure] =
+            tree_flatten_with_structure((args.size() == 1) ? args[0] : args);
+        auto ys = distributed::all_reduce_sum(xs, group);
+        return tree_unflatten_from_structure(structure, ys);
+      },
+      nb::arg(),
       nb::kw_only(),
       "group"_a = nb::none(),
       nb::sig(
-          "def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
+          "def all_reduce_sum(*args, group: Optional[Group] = None) -> array"),
       R"pbdoc(
         All reduce sum.
 
-        Sum the ``x`` arrays from all processes in the group.
+        Sum the passed :class:`array` or the tree of :class:`array` from all
+        processes in the group.
+
+        .. note::
+          In order for the reduction to work, the iteration order of the passed
+          trees of arrays must be the same across machines.
 
         Args:
-          x (array): Input array.
+          *args (arrays or trees of arrays): Input arrays.
           group (Group): The group of processes that will participate in the
             reduction. If set to ``None`` the global group is used. Default:
             ``None``.