mirror of https://github.com/ml-explore/mlx.git

commit 3a1df968cf (parent 9f9cb7a2ef)

    Add async all reduce and donation
mlx/array.h
@@ -252,8 +252,9 @@ class array {
   }

   /** True indicates the arrays buffer is safe to reuse */
-  bool is_donatable() const {
-    return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
+  bool is_donatable(int known_instances = 1) const {
+    return array_desc_.use_count() == known_instances &&
+        (array_desc_->data.use_count() == 1);
   }

   /** The array's siblings. */
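The donation machinery underneath is plain shared_ptr reference counting: a buffer is safe to reuse in place only when every live handle to it is accounted for. A minimal sketch of the idea (standard C++, not the MLX internals):

#include <iostream>
#include <memory>
#include <vector>

int main() {
  auto data = std::make_shared<std::vector<float>>(4, 0.0f);
  // One live handle: the old is_donatable() check (known_instances == 1) holds.
  std::cout << "donatable: " << (data.use_count() == 1) << "\n";

  auto sibling = data; // a second holder of the same buffer
  // Donation is still safe if the caller knows to expect exactly two holders,
  // which is what the new known_instances parameter expresses.
  std::cout << "donatable for two holders: " << (sibling.use_count() == 2)
            << "\n";
  return 0;
}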
mlx/distributed/distributed.h
@@ -53,6 +53,10 @@ Stream communication_stream();

 /* Perform an all reduce sum operation */
 void all_reduce_sum(Group group, const array& input, array& output);
+void all_reduce_sum(
+    Group group,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs);

 /* Perform an all reduce sum operation */
 void all_gather(Group group, const array& input, array& output);
mlx/distributed/mpi/mpi.cpp
@@ -46,7 +46,9 @@ struct MPIWrapper {
     LOAD_SYMBOL(MPI_Comm_split, comm_split);
     LOAD_SYMBOL(MPI_Comm_free, comm_free);
     LOAD_SYMBOL(MPI_Allreduce, all_reduce);
+    LOAD_SYMBOL(MPI_Iallreduce, async_all_reduce);
     LOAD_SYMBOL(MPI_Allgather, all_gather);
+    LOAD_SYMBOL(MPI_Waitall, waitall);

     // Objects
     LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
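LOAD_SYMBOL fills the wrapper's function-pointer table from the MPI shared library at runtime, so the nonblocking path simply binds two more symbols, MPI_Iallreduce and MPI_Waitall. A hedged sketch of that mechanism (illustrative, not the actual MLX macro):

#include <dlfcn.h>
#include <cstdio>

int main() {
  // Open the MPI shared library at runtime instead of linking against it.
  void* lib = dlopen("libmpi.so", RTLD_NOW | RTLD_GLOBAL);
  if (lib == nullptr) {
    std::printf("MPI not available: %s\n", dlerror());
    return 1;
  }
  // Resolve the nonblocking all-reduce entry point by name; the wrapper
  // stores the result in a typed member function pointer.
  void* sym = dlsym(lib, "MPI_Iallreduce");
  std::printf("MPI_Iallreduce %s\n", sym ? "found" : "missing");
  dlclose(lib);
  return 0;
}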
@@ -130,6 +132,14 @@ struct MPIWrapper {
   int (*finalize)();
   int (*rank)(MPI_Comm, int*);
   int (*size)(MPI_Comm, int*);
+  int (*async_all_reduce)(
+      const void*,
+      void*,
+      int,
+      MPI_Datatype,
+      MPI_Op,
+      MPI_Comm,
+      MPI_Request*);
   int (*all_reduce)(const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm);
   int (*all_gather)(
       const void*,
@@ -139,6 +149,7 @@ struct MPIWrapper {
       int,
       MPI_Datatype,
       MPI_Comm);
+  int (*waitall)(int, MPI_Request*, MPI_Status*);
   int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
   int (*comm_free)(MPI_Comm*);

@@ -258,7 +269,8 @@ Stream communication_stream() {
 void all_reduce_sum(Group group, const array& input_, array& output) {
   array input = ensure_row_contiguous(input_);
   mpi().all_reduce(
-      input.data<void>(),
+      (input.data<void>() != output.data<void>()) ? input.data<void>()
+                                                  : MPI_IN_PLACE,
       output.data<void>(),
       input.size(),
       mpi().datatype(input),
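The ternary added here supports donation: when the output has taken over the input's buffer, the send and receive pointers alias, which MPI forbids for MPI_Allreduce; passing MPI_IN_PLACE as the send buffer is the sanctioned way to reduce into the receive buffer. A minimal plain-MPI illustration:

#include <mpi.h>
#include <vector>

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  std::vector<float> buf(8, 1.0f);
  // The reduction happens in the same buffer that holds the input, so
  // MPI_IN_PLACE replaces what would otherwise be an aliased send pointer.
  MPI_Allreduce(
      MPI_IN_PLACE,
      buf.data(),
      static_cast<int>(buf.size()),
      MPI_FLOAT,
      MPI_SUM,
      MPI_COMM_WORLD);
  MPI_Finalize();
  return 0;
}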
@@ -266,6 +278,27 @@ void all_reduce_sum(Group group, const array& input_, array& output) {
       to_comm(group));
 }

+void all_reduce_sum(
+    Group group,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  std::vector<MPI_Request> requests(inputs.size());
+  std::vector<MPI_Status> statuses(inputs.size());
+  for (int i = 0; i < inputs.size(); i++) {
+    array input = ensure_row_contiguous(inputs[i]);
+    mpi().async_all_reduce(
+        (input.data<void>() != outputs[i].data<void>()) ? input.data<void>()
+                                                        : MPI_IN_PLACE,
+        outputs[i].data<void>(),
+        input.size(),
+        mpi().datatype(input),
+        mpi().op_sum(),
+        to_comm(group),
+        &requests[i]);
+  }
+  mpi().waitall(requests.size(), &requests[0], &statuses[0]);
+}
+
 void all_gather(Group group, const array& input_, array& output) {
   array input = ensure_row_contiguous(input_);
   mpi().all_gather(
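This multi-array overload is the heart of the change: it launches one nonblocking reduction per array and waits once for all of them, so the reductions overlap instead of running back to back. A self-contained plain-MPI sketch of the same pattern:

#include <mpi.h>
#include <vector>

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  std::vector<std::vector<float>> bufs = {{1, 2}, {3, 4}, {5, 6}};
  std::vector<MPI_Request> requests(bufs.size());
  for (size_t i = 0; i < bufs.size(); i++) {
    // Each MPI_Iallreduce returns immediately; the reductions proceed
    // concurrently instead of serializing per array.
    MPI_Iallreduce(
        MPI_IN_PLACE,
        bufs[i].data(),
        static_cast<int>(bufs[i].size()),
        MPI_FLOAT,
        MPI_SUM,
        MPI_COMM_WORLD,
        &requests[i]);
  }
  // Block once for the whole batch.
  MPI_Waitall(
      static_cast<int>(requests.size()), requests.data(), MPI_STATUSES_IGNORE);
  MPI_Finalize();
  return 0;
}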
mlx/distributed/no_distributed.cpp
@@ -32,6 +32,10 @@ Stream communication_stream() {
 }

 void all_reduce_sum(Group group, const array& input, array& output) {}
+void all_reduce_sum(
+    Group group,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {}
 void all_gather(Group group, const array& input, array& output) {}

 } // namespace detail
mlx/distributed/ops.cpp
@@ -18,17 +18,33 @@ Group to_group(std::optional<Group> group) {
 } // namespace

 array all_reduce_sum(const array& x, std::optional<Group> group_) {
+  return all_reduce_sum(std::vector<array>{x}, std::move(group_))[0];
+}
+
+std::vector<array> all_reduce_sum(
+    const std::vector<array>& xs,
+    std::optional<Group> group_) {
   auto group = to_group(group_);

   if (group.size() == 1) {
-    return x;
+    return xs;
   }

-  return array(
-      x.shape(),
-      x.dtype(),
+  std::vector<std::vector<int>> shapes;
+  std::vector<Dtype> dtypes;
+  shapes.reserve(xs.size());
+  dtypes.reserve(xs.size());
+
+  for (const auto& x : xs) {
+    shapes.push_back(x.shape());
+    dtypes.push_back(x.dtype());
+  }
+
+  return array::make_arrays(
+      std::move(shapes),
+      std::move(dtypes),
       std::make_shared<AllReduce>(group, AllReduce::Sum),
-      {x});
+      xs);
 }

 array all_gather(const array& x, std::optional<Group> group_) {
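The single-array form now delegates to a vector form that builds one multi-output AllReduce node over every input via array::make_arrays. A hedged caller-side sketch of the new C++ overload (assumes an MLX C++ project setup; the names mirror the declarations added in ops.h):

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto a = ones({4});
  auto b = ones({8});
  // One AllReduce node covers both arrays; with the MPI backend this lowers
  // to two MPI_Iallreduce calls followed by a single MPI_Waitall.
  auto ys = distributed::all_reduce_sum(std::vector<array>{a, b});
  eval(ys);
  // Without an initialized distributed backend the group has size 1 and the
  // call returns the inputs unchanged.
  return 0;
}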
mlx/distributed/ops.h
@@ -9,6 +9,9 @@
 namespace mlx::core::distributed {

 array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
+std::vector<array> all_reduce_sum(
+    const std::vector<array>& xs,
+    std::optional<Group> group = std::nullopt);
 array all_gather(const array& x, std::optional<Group> group = std::nullopt);

 } // namespace mlx::core::distributed
mlx/distributed/primitives.cpp
@@ -13,11 +13,15 @@ namespace mlx::core::distributed {
 void AllReduce::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  assert(inputs.size() == 1);
-  assert(outputs.size() == 1);
-
-  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
+  for (int i = 0; i < inputs.size(); i++) {
+    if (inputs[i].is_donatable(outputs.size())) {
+      outputs[i].copy_shared_buffer(inputs[i]);
+    } else {
+      outputs[i].set_data(allocator::malloc_or_wait(outputs[i].nbytes()));
+    }
+  }

+  if (inputs.size() == 1) {
   switch (reduce_type_) {
     case Sum:
       distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
@@ -25,6 +29,15 @@ void AllReduce::eval_cpu(
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");
   }
+  } else {
+    switch (reduce_type_) {
+      case Sum:
+        distributed::detail::all_reduce_sum(group(), inputs, outputs);
+        break;
+      default:
+        throw std::runtime_error("Only all reduce sum is supported for now");
+    }
+  }
 }

 std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
python/src/distributed.cpp
@@ -6,6 +6,7 @@

 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/ops.h"
+#include "python/src/trees.h"

 namespace nb = nanobind;
 using namespace nb::literals;
@@ -60,19 +61,30 @@ void init_distributed(nb::module_& parent_module) {

   m.def(
       "all_reduce_sum",
-      &distributed::all_reduce_sum,
-      "x"_a,
+      [](const nb::args& args, std::optional<distributed::Group> group) {
+        auto [xs, structure] =
+            tree_flatten_with_structure((args.size() == 1) ? args[0] : args);
+        auto ys = distributed::all_reduce_sum(xs, group);
+        return tree_unflatten_from_structure(structure, ys);
+      },
+      nb::arg(),
       nb::kw_only(),
       "group"_a = nb::none(),
       nb::sig(
-          "def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
+          "def all_reduce_sum(*args, group: Optional[Group] = None) -> array"),
       R"pbdoc(
         All reduce sum.

-        Sum the ``x`` arrays from all processes in the group.
+        Sum the passed :class:`array` or the tree of :class:`array` from all
+        processes in the group.
+
+        .. note::
+
+          In order for the reduction to work, the iteration order of the passed
+          trees of arrays must be the same across machines.

         Args:
-          x (array): Input array.
+          *args (arrays or trees of arrays): Input arrays.
           group (Group): The group of processes that will participate in the
             reduction. If set to ``None`` the global group is used. Default:
             ``None``.
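With this binding, all_reduce_sum accepts a single array, several positional arrays, or arbitrary trees of arrays: the lambda flattens the arguments with tree_flatten_with_structure, reduces all leaves through the one multi-output primitive, and rebuilds the original structure with tree_unflatten_from_structure. The docstring note matters because the flattening order decides which leaf is paired with which across processes; mismatched tree iteration orders would silently sum unrelated arrays.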