diff --git a/mlx/array.h b/mlx/array.h
index 96b6b971e..7b0a14fa6 100644
--- a/mlx/array.h
+++ b/mlx/array.h
@@ -252,8 +252,9 @@ class array {
   }
 
   /** True indicates the arrays buffer is safe to reuse */
-  bool is_donatable() const {
-    return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
+  bool is_donatable(int known_instances = 1) const {
+    return array_desc_.use_count() == known_instances &&
+        (array_desc_->data.use_count() == 1);
   }
 
   /** The array's siblings. */
diff --git a/mlx/distributed/distributed.h b/mlx/distributed/distributed.h
index cad75b396..bfbd8b399 100644
--- a/mlx/distributed/distributed.h
+++ b/mlx/distributed/distributed.h
@@ -53,6 +53,10 @@ Stream communication_stream();
 
 /* Perform an all reduce sum operation */
 void all_reduce_sum(Group group, const array& input, array& output);
+void all_reduce_sum(
+    Group group,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs);
 
 /* Perform an all reduce sum operation */
 void all_gather(Group group, const array& input, array& output);
diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp
index 3d1818195..c76e8d004 100644
--- a/mlx/distributed/mpi/mpi.cpp
+++ b/mlx/distributed/mpi/mpi.cpp
@@ -46,7 +46,9 @@ struct MPIWrapper {
     LOAD_SYMBOL(MPI_Comm_split, comm_split);
     LOAD_SYMBOL(MPI_Comm_free, comm_free);
     LOAD_SYMBOL(MPI_Allreduce, all_reduce);
+    LOAD_SYMBOL(MPI_Iallreduce, async_all_reduce);
     LOAD_SYMBOL(MPI_Allgather, all_gather);
+    LOAD_SYMBOL(MPI_Waitall, waitall);
 
     // Objects
     LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@@ -130,6 +132,14 @@ struct MPIWrapper {
   int (*finalize)();
   int (*rank)(MPI_Comm, int*);
   int (*size)(MPI_Comm, int*);
+  int (*async_all_reduce)(
+      const void*,
+      void*,
+      int,
+      MPI_Datatype,
+      MPI_Op,
+      MPI_Comm,
+      MPI_Request*);
   int (*all_reduce)(const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm);
   int (*all_gather)(
       const void*,
@@ -139,6 +149,7 @@ struct MPIWrapper {
       int,
       MPI_Datatype,
       MPI_Comm);
+  int (*waitall)(int, MPI_Request*, MPI_Status*);
   int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
   int (*comm_free)(MPI_Comm*);
 
@@ -258,7 +269,8 @@ Stream communication_stream() {
 void all_reduce_sum(Group group, const array& input_, array& output) {
   array input = ensure_row_contiguous(input_);
   mpi().all_reduce(
-      input.data<void>(),
+      (input.data<void>() != output.data<void>()) ? input.data<void>()
+                                                  : MPI_IN_PLACE,
       output.data<void>(),
       input.size(),
       mpi().datatype(input),
@@ -266,6 +278,27 @@ void all_reduce_sum(Group group, const array& input_, array& output) {
       to_comm(group));
 }
 
+void all_reduce_sum(
+    Group group,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  std::vector<MPI_Request> requests(inputs.size());
+  std::vector<MPI_Status> statuses(inputs.size());
+  for (int i = 0; i < inputs.size(); i++) {
+    array input = ensure_row_contiguous(inputs[i]);
+    mpi().async_all_reduce(
+        (input.data<void>() != outputs[i].data<void>()) ? input.data<void>()
+                                                        : MPI_IN_PLACE,
+        outputs[i].data<void>(),
+        input.size(),
+        mpi().datatype(input),
+        mpi().op_sum(),
+        to_comm(group),
+        &requests[i]);
+  }
+  mpi().waitall(requests.size(), &requests[0], &statuses[0]);
+}
+
 void all_gather(Group group, const array& input_, array& output) {
   array input = ensure_row_contiguous(input_);
   mpi().all_gather(
diff --git a/mlx/distributed/no_distributed.cpp b/mlx/distributed/no_distributed.cpp
index d85428496..2d1272d9c 100644
--- a/mlx/distributed/no_distributed.cpp
+++ b/mlx/distributed/no_distributed.cpp
@@ -32,6 +32,10 @@ Stream communication_stream() {
 }
 
 void all_reduce_sum(Group group, const array& input, array& output) {}
+void all_reduce_sum(
+    Group group,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {}
 void all_gather(Group group, const array& input, array& output) {}
 
 } // namespace detail
diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp
index 69cf196cb..e9f7c4227 100644
--- a/mlx/distributed/ops.cpp
+++ b/mlx/distributed/ops.cpp
@@ -18,17 +18,33 @@ Group to_group(std::optional<Group> group) {
 } // namespace
 
 array all_reduce_sum(const array& x, std::optional<Group> group_) {
+  return all_reduce_sum(std::vector<array>{x}, std::move(group_))[0];
+}
+
+std::vector<array> all_reduce_sum(
+    const std::vector<array>& xs,
+    std::optional<Group> group_) {
   auto group = to_group(group_);
 
   if (group.size() == 1) {
-    return x;
+    return xs;
   }
 
-  return array(
-      x.shape(),
-      x.dtype(),
+  std::vector<std::vector<int>> shapes;
+  std::vector<Dtype> dtypes;
+  shapes.reserve(xs.size());
+  dtypes.reserve(xs.size());
+
+  for (const auto& x : xs) {
+    shapes.push_back(x.shape());
+    dtypes.push_back(x.dtype());
+  }
+
+  return array::make_arrays(
+      std::move(shapes),
+      std::move(dtypes),
       std::make_shared<AllReduce>(group, AllReduce::Sum),
-      {x});
+      xs);
 }
 
 array all_gather(const array& x, std::optional<Group> group_) {
diff --git a/mlx/distributed/ops.h b/mlx/distributed/ops.h
index 1afe8dcc8..29dba6f92 100644
--- a/mlx/distributed/ops.h
+++ b/mlx/distributed/ops.h
@@ -9,6 +9,9 @@ namespace mlx::core::distributed {
 
 array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
+std::vector<array> all_reduce_sum(
+    const std::vector<array>& xs,
+    std::optional<Group> group = std::nullopt);
 array all_gather(const array& x, std::optional<Group> group = std::nullopt);
 
 } // namespace mlx::core::distributed
 
diff --git a/mlx/distributed/primitives.cpp b/mlx/distributed/primitives.cpp
index b20fde605..02cff91bf 100644
--- a/mlx/distributed/primitives.cpp
+++ b/mlx/distributed/primitives.cpp
@@ -13,17 +13,30 @@ namespace mlx::core::distributed {
 void AllReduce::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  assert(inputs.size() == 1);
-  assert(outputs.size() == 1);
+  for (int i = 0; i < inputs.size(); i++) {
+    if (inputs[i].is_donatable(outputs.size())) {
+      outputs[i].copy_shared_buffer(inputs[i]);
+    } else {
+      outputs[i].set_data(allocator::malloc_or_wait(outputs[i].nbytes()));
+    }
+  }
 
-  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
-
-  switch (reduce_type_) {
-    case Sum:
-      distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
-      break;
-    default:
-      throw std::runtime_error("Only all reduce sum is supported for now");
+  if (inputs.size() == 1) {
+    switch (reduce_type_) {
+      case Sum:
+        distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
+        break;
+      default:
+        throw std::runtime_error("Only all reduce sum is supported for now");
+    }
+  } else {
+    switch (reduce_type_) {
+      case Sum:
+        distributed::detail::all_reduce_sum(group(), inputs, outputs);
+        break;
+      default:
+        throw std::runtime_error("Only all reduce sum is supported for now");
+    }
   }
 }
 
diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp
index 069b5a885..f036f65d7 100644
--- a/python/src/distributed.cpp
+++ b/python/src/distributed.cpp
@@ -6,6 +6,7 @@
 
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/ops.h"
+#include "python/src/trees.h"
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -60,19 +61,30 @@ void init_distributed(nb::module_& parent_module) {
 
   m.def(
       "all_reduce_sum",
-      &distributed::all_reduce_sum,
-      "x"_a,
+      [](const nb::args& args, std::optional<distributed::Group> group) {
+        auto [xs, structure] =
+            tree_flatten_with_structure((args.size() == 1) ? args[0] : args);
+        auto ys = distributed::all_reduce_sum(xs, group);
+        return tree_unflatten_from_structure(structure, ys);
+      },
+      nb::arg(),
       nb::kw_only(),
       "group"_a = nb::none(),
       nb::sig(
-          "def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
+          "def all_reduce_sum(*args, group: Optional[Group] = None) -> array"),
       R"pbdoc(
        All reduce sum.
 
-        Sum the ``x`` arrays from all processes in the group.
+        Sum the passed :class:`array` or tree of :class:`array` from all
+        processes in the group.
+
+        .. note::
+
+          For the reduction to work, the iteration order of the passed
+          trees of arrays must be the same across machines.
 
        Args:
-          x (array): Input array.
+          *args (arrays or trees of arrays): Input arrays.
          group (Group): The group of processes that will participate in the
            reduction. If set to ``None`` the global group is used.
            Default: ``None``.
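
Usage sketch (illustrative, not part of the patch): the rewritten binding flattens whatever tree of arrays it receives, reduces all leaves in one batched primitive (one MPI_Iallreduce per array, completed by a single MPI_Waitall on the MPI path), and restores the original structure. A minimal Python example, assuming MLX was built with MPI support, a launcher such as `mpirun -n 2 python example.py`, and MLX's existing `mlx.core.distributed.init()` entry point:

    # Hypothetical example of the tree-aware all_reduce_sum added above.
    import mlx.core as mx
    import mlx.core.distributed as dist

    dist.init()  # set up the global process group

    # Any tree of arrays works (here, a dict). Every process must build the
    # tree with the same iteration order so the flattened leaves line up.
    grads = {"w": mx.ones((4, 4)), "b": mx.ones((4,))}

    # One call sums every leaf across all processes and returns a tree with
    # the same structure.
    summed = dist.all_reduce_sum(grads)
    mx.eval(summed)

The is_donatable(known_instances) change exists to keep buffer donation working on this batched path: each of the N sibling outputs holds a reference to the same input, so a use count of N (rather than 1) is what marks an input as otherwise unreferenced, which is why eval_cpu passes outputs.size().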