mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Add async all reduce and donation
This commit is contained in:
parent
9f9cb7a2ef
commit
3a1df968cf
@ -252,8 +252,9 @@ class array {
|
||||
}
|
||||
|
||||
/** True indicates the arrays buffer is safe to reuse */
|
||||
bool is_donatable() const {
|
||||
return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
|
||||
bool is_donatable(int known_instances = 1) const {
|
||||
return array_desc_.use_count() == known_instances &&
|
||||
(array_desc_->data.use_count() == 1);
|
||||
}
|
||||
|
||||
/** The array's siblings. */
|
||||
|
@ -53,6 +53,10 @@ Stream communication_stream();
|
||||
|
||||
/* Perform an all reduce sum operation */
|
||||
void all_reduce_sum(Group group, const array& input, array& output);
|
||||
void all_reduce_sum(
|
||||
Group group,
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs);
|
||||
|
||||
/* Perform an all reduce sum operation */
|
||||
void all_gather(Group group, const array& input, array& output);
|
||||
|
@ -46,7 +46,9 @@ struct MPIWrapper {
|
||||
LOAD_SYMBOL(MPI_Comm_split, comm_split);
|
||||
LOAD_SYMBOL(MPI_Comm_free, comm_free);
|
||||
LOAD_SYMBOL(MPI_Allreduce, all_reduce);
|
||||
LOAD_SYMBOL(MPI_Iallreduce, async_all_reduce);
|
||||
LOAD_SYMBOL(MPI_Allgather, all_gather);
|
||||
LOAD_SYMBOL(MPI_Waitall, waitall);
|
||||
|
||||
// Objects
|
||||
LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
|
||||
@ -130,6 +132,14 @@ struct MPIWrapper {
|
||||
int (*finalize)();
|
||||
int (*rank)(MPI_Comm, int*);
|
||||
int (*size)(MPI_Comm, int*);
|
||||
int (*async_all_reduce)(
|
||||
const void*,
|
||||
void*,
|
||||
int,
|
||||
MPI_Datatype,
|
||||
MPI_Op,
|
||||
MPI_Comm,
|
||||
MPI_Request*);
|
||||
int (*all_reduce)(const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm);
|
||||
int (*all_gather)(
|
||||
const void*,
|
||||
@ -139,6 +149,7 @@ struct MPIWrapper {
|
||||
int,
|
||||
MPI_Datatype,
|
||||
MPI_Comm);
|
||||
int (*waitall)(int, MPI_Request*, MPI_Status*);
|
||||
int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
|
||||
int (*comm_free)(MPI_Comm*);
|
||||
|
||||
@ -258,7 +269,8 @@ Stream communication_stream() {
|
||||
void all_reduce_sum(Group group, const array& input_, array& output) {
|
||||
array input = ensure_row_contiguous(input_);
|
||||
mpi().all_reduce(
|
||||
input.data<void>(),
|
||||
(input.data<void>() != output.data<void>()) ? input.data<void>()
|
||||
: MPI_IN_PLACE,
|
||||
output.data<void>(),
|
||||
input.size(),
|
||||
mpi().datatype(input),
|
||||
@ -266,6 +278,27 @@ void all_reduce_sum(Group group, const array& input_, array& output) {
|
||||
to_comm(group));
|
||||
}
|
||||
|
||||
void all_reduce_sum(
|
||||
Group group,
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
std::vector<MPI_Request> requests(inputs.size());
|
||||
std::vector<MPI_Status> statuses(inputs.size());
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
array input = ensure_row_contiguous(inputs[i]);
|
||||
mpi().async_all_reduce(
|
||||
(input.data<void>() != outputs[i].data<void>()) ? input.data<void>()
|
||||
: MPI_IN_PLACE,
|
||||
outputs[i].data<void>(),
|
||||
input.size(),
|
||||
mpi().datatype(input),
|
||||
mpi().op_sum(),
|
||||
to_comm(group),
|
||||
&requests[i]);
|
||||
}
|
||||
mpi().waitall(requests.size(), &requests[0], &statuses[0]);
|
||||
}
|
||||
|
||||
void all_gather(Group group, const array& input_, array& output) {
|
||||
array input = ensure_row_contiguous(input_);
|
||||
mpi().all_gather(
|
||||
|
@ -32,6 +32,10 @@ Stream communication_stream() {
|
||||
}
|
||||
|
||||
void all_reduce_sum(Group group, const array& input, array& output) {}
|
||||
void all_reduce_sum(
|
||||
Group group,
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {}
|
||||
void all_gather(Group group, const array& input, array& output) {}
|
||||
|
||||
} // namespace detail
|
||||
|
@ -18,17 +18,33 @@ Group to_group(std::optional<Group> group) {
|
||||
} // namespace
|
||||
|
||||
array all_reduce_sum(const array& x, std::optional<Group> group_) {
|
||||
return all_reduce_sum(std::vector<array>{x}, std::move(group_))[0];
|
||||
}
|
||||
|
||||
std::vector<array> all_reduce_sum(
|
||||
const std::vector<array>& xs,
|
||||
std::optional<Group> group_) {
|
||||
auto group = to_group(group_);
|
||||
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
return xs;
|
||||
}
|
||||
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<Dtype> dtypes;
|
||||
shapes.reserve(xs.size());
|
||||
dtypes.reserve(xs.size());
|
||||
|
||||
for (const auto& x : xs) {
|
||||
shapes.push_back(x.shape());
|
||||
dtypes.push_back(x.dtype());
|
||||
}
|
||||
|
||||
return array::make_arrays(
|
||||
std::move(shapes),
|
||||
std::move(dtypes),
|
||||
std::make_shared<AllReduce>(group, AllReduce::Sum),
|
||||
{x});
|
||||
xs);
|
||||
}
|
||||
|
||||
array all_gather(const array& x, std::optional<Group> group_) {
|
||||
|
@ -9,6 +9,9 @@
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
|
||||
std::vector<array> all_reduce_sum(
|
||||
const std::vector<array>& xs,
|
||||
std::optional<Group> group = std::nullopt);
|
||||
array all_gather(const array& x, std::optional<Group> group = std::nullopt);
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
@ -13,17 +13,30 @@ namespace mlx::core::distributed {
|
||||
void AllReduce::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
if (inputs[i].is_donatable(outputs.size())) {
|
||||
outputs[i].copy_shared_buffer(inputs[i]);
|
||||
} else {
|
||||
outputs[i].set_data(allocator::malloc_or_wait(outputs[i].nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
|
||||
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
if (inputs.size() == 1) {
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
} else {
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
distributed::detail::all_reduce_sum(group(), inputs, outputs);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
@ -60,19 +61,30 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"all_reduce_sum",
|
||||
&distributed::all_reduce_sum,
|
||||
"x"_a,
|
||||
[](const nb::args& args, std::optional<distributed::Group> group) {
|
||||
auto [xs, structure] =
|
||||
tree_flatten_with_structure((args.size() == 1) ? args[0] : args);
|
||||
auto ys = distributed::all_reduce_sum(xs, group);
|
||||
return tree_unflatten_from_structure(structure, ys);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
|
||||
"def all_reduce_sum(*args, group: Optional[Group] = None) -> array"),
|
||||
R"pbdoc(
|
||||
All reduce sum.
|
||||
|
||||
Sum the ``x`` arrays from all processes in the group.
|
||||
Sum the passed :class:`array` or the tree of :class:`array` from all
|
||||
processes in the group.
|
||||
|
||||
.. note::
|
||||
|
||||
In order for the reduction to work, the iteration order of the passed
|
||||
trees of arrays must be the same across machines.
|
||||
|
||||
Args:
|
||||
x (array): Input array.
|
||||
*args (arrays or trees of arrays): Input arrays.
|
||||
group (Group): The group of processes that will participate in the
|
||||
reduction. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
|
Loading…
Reference in New Issue
Block a user