mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-06 20:20:11 +08:00
Add async all reduce and donation
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
@@ -60,19 +61,30 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"all_reduce_sum",
|
||||
&distributed::all_reduce_sum,
|
||||
"x"_a,
|
||||
[](const nb::args& args, std::optional<distributed::Group> group) {
|
||||
auto [xs, structure] =
|
||||
tree_flatten_with_structure((args.size() == 1) ? args[0] : args);
|
||||
auto ys = distributed::all_reduce_sum(xs, group);
|
||||
return tree_unflatten_from_structure(structure, ys);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
|
||||
"def all_reduce_sum(*args, group: Optional[Group] = None) -> array"),
|
||||
R"pbdoc(
|
||||
All reduce sum.
|
||||
|
||||
Sum the ``x`` arrays from all processes in the group.
|
||||
Sum the passed :class:`array` or the tree of :class:`array` from all
|
||||
processes in the group.
|
||||
|
||||
.. note::
|
||||
|
||||
In order for the reduction to work, the iteration order of the passed
|
||||
trees of arrays must be the same across machines.
|
||||
|
||||
Args:
|
||||
x (array): Input array.
|
||||
*args (arrays or trees of arrays): Input arrays.
|
||||
group (Group): The group of processes that will participate in the
|
||||
reduction. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
|
||||
Reference in New Issue
Block a user