Add async all reduce and donation

This commit is contained in:
Angelos Katharopoulos
2024-05-29 22:54:32 -07:00
parent 9f9cb7a2ef
commit 3a1df968cf
8 changed files with 109 additions and 23 deletions

View File

@@ -6,6 +6,7 @@
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
#include "python/src/trees.h"
namespace nb = nanobind;
using namespace nb::literals;
@@ -60,19 +61,30 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"all_reduce_sum",
&distributed::all_reduce_sum,
"x"_a,
[](const nb::args& args, std::optional<distributed::Group> group) {
auto [xs, structure] =
tree_flatten_with_structure((args.size() == 1) ? args[0] : args);
auto ys = distributed::all_reduce_sum(xs, group);
return tree_unflatten_from_structure(structure, ys);
},
nb::arg(),
nb::kw_only(),
"group"_a = nb::none(),
nb::sig(
"def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
"def all_reduce_sum(*args, group: Optional[Group] = None) -> array"),
R"pbdoc(
All reduce sum.
Sum the ``x`` arrays from all processes in the group.
Sum the passed :class:`array` or the tree of :class:`array` from all
processes in the group.
.. note::
In order for the reduction to work, the iteration order of the passed
trees of arrays must be the same across all processes in the group.
Args:
x (array): Input array.
*args (arrays or trees of arrays): Input arrays.
group (Group): The group of processes that will participate in the
reduction. If set to ``None`` the global group is used. Default:
``None``.