Add async all reduce and donation

This commit is contained in:
Angelos Katharopoulos 2024-05-29 22:54:32 -07:00
parent 9f9cb7a2ef
commit 3a1df968cf
8 changed files with 109 additions and 23 deletions

View File

@ -252,8 +252,9 @@ class array {
}
/** True indicates the arrays buffer is safe to reuse */
bool is_donatable() const {
return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
bool is_donatable(int known_instances = 1) const {
return array_desc_.use_count() == known_instances &&
(array_desc_->data.use_count() == 1);
}
/** The array's siblings. */

View File

@ -53,6 +53,10 @@ Stream communication_stream();
/* Perform an all reduce sum operation */
void all_reduce_sum(Group group, const array& input, array& output);
void all_reduce_sum(
Group group,
const std::vector<array>& inputs,
std::vector<array>& outputs);
/* Perform an all reduce sum operation */
void all_gather(Group group, const array& input, array& output);

View File

@ -46,7 +46,9 @@ struct MPIWrapper {
LOAD_SYMBOL(MPI_Comm_split, comm_split);
LOAD_SYMBOL(MPI_Comm_free, comm_free);
LOAD_SYMBOL(MPI_Allreduce, all_reduce);
LOAD_SYMBOL(MPI_Iallreduce, async_all_reduce);
LOAD_SYMBOL(MPI_Allgather, all_gather);
LOAD_SYMBOL(MPI_Waitall, waitall);
// Objects
LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@ -130,6 +132,14 @@ struct MPIWrapper {
int (*finalize)();
int (*rank)(MPI_Comm, int*);
int (*size)(MPI_Comm, int*);
int (*async_all_reduce)(
const void*,
void*,
int,
MPI_Datatype,
MPI_Op,
MPI_Comm,
MPI_Request*);
int (*all_reduce)(const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm);
int (*all_gather)(
const void*,
@ -139,6 +149,7 @@ struct MPIWrapper {
int,
MPI_Datatype,
MPI_Comm);
int (*waitall)(int, MPI_Request*, MPI_Status*);
int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
int (*comm_free)(MPI_Comm*);
@ -258,7 +269,8 @@ Stream communication_stream() {
void all_reduce_sum(Group group, const array& input_, array& output) {
array input = ensure_row_contiguous(input_);
mpi().all_reduce(
input.data<void>(),
(input.data<void>() != output.data<void>()) ? input.data<void>()
: MPI_IN_PLACE,
output.data<void>(),
input.size(),
mpi().datatype(input),
@ -266,6 +278,27 @@ void all_reduce_sum(Group group, const array& input_, array& output) {
to_comm(group));
}
void all_reduce_sum(
Group group,
const std::vector<array>& inputs,
std::vector<array>& outputs) {
std::vector<MPI_Request> requests(inputs.size());
std::vector<MPI_Status> statuses(inputs.size());
for (int i = 0; i < inputs.size(); i++) {
array input = ensure_row_contiguous(inputs[i]);
mpi().async_all_reduce(
(input.data<void>() != outputs[i].data<void>()) ? input.data<void>()
: MPI_IN_PLACE,
outputs[i].data<void>(),
input.size(),
mpi().datatype(input),
mpi().op_sum(),
to_comm(group),
&requests[i]);
}
mpi().waitall(requests.size(), &requests[0], &statuses[0]);
}
void all_gather(Group group, const array& input_, array& output) {
array input = ensure_row_contiguous(input_);
mpi().all_gather(

View File

@ -32,6 +32,10 @@ Stream communication_stream() {
}
void all_reduce_sum(Group group, const array& input, array& output) {}
void all_reduce_sum(
Group group,
const std::vector<array>& inputs,
std::vector<array>& outputs) {}
void all_gather(Group group, const array& input, array& output) {}
} // namespace detail

View File

@ -18,17 +18,33 @@ Group to_group(std::optional<Group> group) {
} // namespace
array all_reduce_sum(const array& x, std::optional<Group> group_) {
return all_reduce_sum(std::vector<array>{x}, std::move(group_))[0];
}
std::vector<array> all_reduce_sum(
const std::vector<array>& xs,
std::optional<Group> group_) {
auto group = to_group(group_);
if (group.size() == 1) {
return x;
return xs;
}
return array(
x.shape(),
x.dtype(),
std::vector<std::vector<int>> shapes;
std::vector<Dtype> dtypes;
shapes.reserve(xs.size());
dtypes.reserve(xs.size());
for (const auto& x : xs) {
shapes.push_back(x.shape());
dtypes.push_back(x.dtype());
}
return array::make_arrays(
std::move(shapes),
std::move(dtypes),
std::make_shared<AllReduce>(group, AllReduce::Sum),
{x});
xs);
}
array all_gather(const array& x, std::optional<Group> group_) {

View File

@ -9,6 +9,9 @@
namespace mlx::core::distributed {
array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
std::vector<array> all_reduce_sum(
const std::vector<array>& xs,
std::optional<Group> group = std::nullopt);
array all_gather(const array& x, std::optional<Group> group = std::nullopt);
} // namespace mlx::core::distributed

View File

@ -13,11 +13,15 @@ namespace mlx::core::distributed {
void AllReduce::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
for (int i = 0; i < inputs.size(); i++) {
if (inputs[i].is_donatable(outputs.size())) {
outputs[i].copy_shared_buffer(inputs[i]);
} else {
outputs[i].set_data(allocator::malloc_or_wait(outputs[i].nbytes()));
}
}
if (inputs.size() == 1) {
switch (reduce_type_) {
case Sum:
distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
@ -25,6 +29,15 @@ void AllReduce::eval_cpu(
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
} else {
switch (reduce_type_) {
case Sum:
distributed::detail::all_reduce_sum(group(), inputs, outputs);
break;
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
}
}
std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(

View File

@ -6,6 +6,7 @@
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
#include "python/src/trees.h"
namespace nb = nanobind;
using namespace nb::literals;
@ -60,19 +61,30 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"all_reduce_sum",
&distributed::all_reduce_sum,
"x"_a,
[](const nb::args& args, std::optional<distributed::Group> group) {
auto [xs, structure] =
tree_flatten_with_structure((args.size() == 1) ? args[0] : args);
auto ys = distributed::all_reduce_sum(xs, group);
return tree_unflatten_from_structure(structure, ys);
},
nb::arg(),
nb::kw_only(),
"group"_a = nb::none(),
nb::sig(
"def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
"def all_reduce_sum(*args, group: Optional[Group] = None) -> array"),
R"pbdoc(
All reduce sum.
Sum the ``x`` arrays from all processes in the group.
Sum the passed :class:`array` or the tree of :class:`array` from all
processes in the group.
.. note::
In order for the reduction to work, the iteration order of the passed
trees of arrays must be the same across machines.
Args:
x (array): Input array.
*args (arrays or trees of arrays): Input arrays.
group (Group): The group of processes that will participate in the
reduction. If set to ``None`` the global group is used. Default:
``None``.