	Add async all reduce and donation
@@ -252,8 +252,9 @@ class array {
   }
 
   /** True indicates the arrays buffer is safe to reuse */
-  bool is_donatable() const {
-    return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
+  bool is_donatable(int known_instances = 1) const {
+    return array_desc_.use_count() == known_instances &&
+        (array_desc_->data.use_count() == 1);
   }
 
   /** The array's siblings. */
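The donation check relies on std::shared_ptr reference counts: the descriptor handle may legitimately be held `known_instances` times (for example, once per sibling output of a multi-output primitive), but the underlying buffer must still be uniquely owned. A minimal self-contained sketch of the idea, with plain shared_ptrs standing in for MLX's ArrayDesc bookkeeping (the names below are illustrative, not MLX code):

#include <iostream>
#include <memory>

// Illustrative stand-in: an "array" is a handle to shared state, and a buffer
// can be donated only when the caller holds the sole references to it.
struct ArrayDesc {
  std::shared_ptr<int> data; // stand-in for the underlying buffer
};

struct Array {
  std::shared_ptr<ArrayDesc> desc;

  // Mirrors the patched check: the descriptor may be referenced
  // `known_instances` times, but the buffer must be uniquely owned.
  bool is_donatable(int known_instances = 1) const {
    return desc.use_count() == known_instances && desc->data.use_count() == 1;
  }
};

int main() {
  Array a{std::make_shared<ArrayDesc>()};
  a.desc->data = std::make_shared<int>(42);
  std::cout << a.is_donatable() << "\n";  // 1: sole owner, safe to reuse
  Array alias = a;                        // a second handle appears
  std::cout << a.is_donatable() << "\n";  // 0: with the default of 1
  std::cout << a.is_donatable(2) << "\n"; // 1: both handles accounted for
}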
@@ -53,6 +53,10 @@ Stream communication_stream();
 
 /* Perform an all reduce sum operation */
 void all_reduce_sum(Group group, const array& input, array& output);
+void all_reduce_sum(
+    Group group,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs);
 
 /* Perform an all reduce sum operation */
 void all_gather(Group group, const array& input, array& output);
@@ -46,7 +46,9 @@ struct MPIWrapper {
     LOAD_SYMBOL(MPI_Comm_split, comm_split);
     LOAD_SYMBOL(MPI_Comm_free, comm_free);
     LOAD_SYMBOL(MPI_Allreduce, all_reduce);
+    LOAD_SYMBOL(MPI_Iallreduce, async_all_reduce);
     LOAD_SYMBOL(MPI_Allgather, all_gather);
+    LOAD_SYMBOL(MPI_Waitall, waitall);
 
     // Objects
     LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
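The wrapper resolves the two new entry points, MPI_Iallreduce and MPI_Waitall, at runtime. A hedged sketch of what a LOAD_SYMBOL-style macro typically does, assuming libmpi has already been dlopen'd (the helper below is illustrative, not the actual macro):

#include <dlfcn.h>
#include <cstdio>

// Illustrative helper: look up `name` in an already-dlopen'd library and
// store the result in a typed function pointer, warning on failure.
template <typename T>
bool load_symbol(void* lib, const char* name, T& fn) {
  fn = reinterpret_cast<T>(dlsym(lib, name));
  if (fn == nullptr) {
    std::fprintf(stderr, "Could not load MPI symbol %s\n", name);
    return false;
  }
  return true;
}

// Usage sketch:
//   int (*waitall)(int, MPI_Request*, MPI_Status*);
//   load_symbol(libmpi_handle, "MPI_Waitall", waitall);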
@@ -130,6 +132,14 @@ struct MPIWrapper {
   int (*finalize)();
   int (*rank)(MPI_Comm, int*);
   int (*size)(MPI_Comm, int*);
+  int (*async_all_reduce)(
+      const void*,
+      void*,
+      int,
+      MPI_Datatype,
+      MPI_Op,
+      MPI_Comm,
+      MPI_Request*);
   int (*all_reduce)(const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm);
   int (*all_gather)(
       const void*,
@@ -139,6 +149,7 @@ struct MPIWrapper {
       int,
       MPI_Datatype,
       MPI_Comm);
+  int (*waitall)(int, MPI_Request*, MPI_Status*);
   int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
   int (*comm_free)(MPI_Comm*);
 
@@ -258,7 +269,8 @@ Stream communication_stream() {
 void all_reduce_sum(Group group, const array& input_, array& output) {
   array input = ensure_row_contiguous(input_);
   mpi().all_reduce(
-      input.data<void>(),
+      (input.data<void>() != output.data<void>()) ? input.data<void>()
+                                                  : MPI_IN_PLACE,
       output.data<void>(),
       input.size(),
       mpi().datatype(input),
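The new ternary guards against aliased buffers: once donation lands, the output may share the input's buffer, and the MPI standard forbids passing the same pointer as both sendbuf and recvbuf; MPI_IN_PLACE must be used instead. A standalone illustration of that rule (the helper name is ours, not MLX's):

#include <mpi.h>

// When reducing into the same buffer the data is sent from, MPI forbids
// passing the pointer twice; MPI_IN_PLACE must be used as sendbuf.
void allreduce_sum_maybe_in_place(
    const void* send,
    void* recv,
    int count,
    MPI_Datatype dtype,
    MPI_Comm comm) {
  MPI_Allreduce(
      send == recv ? MPI_IN_PLACE : send, recv, count, dtype, MPI_SUM, comm);
}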
@@ -266,6 +278,27 @@ void all_reduce_sum(Group group, const array& input_, array& output) {
       to_comm(group));
 }
 
+void all_reduce_sum(
+    Group group,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  std::vector<MPI_Request> requests(inputs.size());
+  std::vector<MPI_Status> statuses(inputs.size());
+  for (int i = 0; i < inputs.size(); i++) {
+    array input = ensure_row_contiguous(inputs[i]);
+    mpi().async_all_reduce(
+        (input.data<void>() != outputs[i].data<void>()) ? input.data<void>()
+                                                        : MPI_IN_PLACE,
+        outputs[i].data<void>(),
+        input.size(),
+        mpi().datatype(input),
+        mpi().op_sum(),
+        to_comm(group),
+        &requests[i]);
+  }
+  mpi().waitall(requests.size(), &requests[0], &statuses[0]);
+}
+
 void all_gather(Group group, const array& input_, array& output) {
   array input = ensure_row_contiguous(input_);
   mpi().all_gather(
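This vector overload is the point of loading MPI_Iallreduce and MPI_Waitall: every reduction is posted as a nonblocking operation and the whole batch is completed with a single wait, so the reductions can progress concurrently instead of serializing. The same pattern in plain MPI, as a standalone sketch (not MLX code):

#include <mpi.h>
#include <vector>

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);

  std::vector<float> a(1024, 1.0f), b(2048, 2.0f);
  std::vector<MPI_Request> requests(2);
  std::vector<MPI_Status> statuses(2);

  // Post both reductions without blocking; MPI_IN_PLACE reduces into the
  // send buffer itself, mirroring the donation path above.
  MPI_Iallreduce(MPI_IN_PLACE, a.data(), static_cast<int>(a.size()),
                 MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD, &requests[0]);
  MPI_Iallreduce(MPI_IN_PLACE, b.data(), static_cast<int>(b.size()),
                 MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD, &requests[1]);

  // Complete the whole batch with one call.
  MPI_Waitall(
      static_cast<int>(requests.size()), requests.data(), statuses.data());

  MPI_Finalize();
  return 0;
}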
@@ -32,6 +32,10 @@ Stream communication_stream() {
 }
 
 void all_reduce_sum(Group group, const array& input, array& output) {}
+void all_reduce_sum(
+    Group group,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {}
 void all_gather(Group group, const array& input, array& output) {}
 
 } // namespace detail
@@ -18,17 +18,33 @@ Group to_group(std::optional<Group> group) {
 } // namespace
 
 array all_reduce_sum(const array& x, std::optional<Group> group_) {
+  return all_reduce_sum(std::vector<array>{x}, std::move(group_))[0];
+}
+
+std::vector<array> all_reduce_sum(
+    const std::vector<array>& xs,
+    std::optional<Group> group_) {
   auto group = to_group(group_);
 
   if (group.size() == 1) {
-    return x;
+    return xs;
   }
 
-  return array(
-      x.shape(),
-      x.dtype(),
+  std::vector<std::vector<int>> shapes;
+  std::vector<Dtype> dtypes;
+  shapes.reserve(xs.size());
+  dtypes.reserve(xs.size());
+
+  for (const auto& x : xs) {
+    shapes.push_back(x.shape());
+    dtypes.push_back(x.dtype());
+  }
+
+  return array::make_arrays(
+      std::move(shapes),
+      std::move(dtypes),
      std::make_shared<AllReduce>(group, AllReduce::Sum),
-      {x});
+      xs);
 }
 
 array all_gather(const array& x, std::optional<Group> group_) {
@@ -9,6 +9,9 @@
 namespace mlx::core::distributed {
 
 array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
+std::vector<array> all_reduce_sum(
+    const std::vector<array>& xs,
+    std::optional<Group> group = std::nullopt);
 array all_gather(const array& x, std::optional<Group> group = std::nullopt);
 
 } // namespace mlx::core::distributed
@@ -13,17 +13,30 @@ namespace mlx::core::distributed {
 void AllReduce::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  assert(inputs.size() == 1);
-  assert(outputs.size() == 1);
-
-  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
-
-  switch (reduce_type_) {
-    case Sum:
-      distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
-      break;
-    default:
-      throw std::runtime_error("Only all reduce sum is supported for now");
+  for (int i = 0; i < inputs.size(); i++) {
+    if (inputs[i].is_donatable(outputs.size())) {
+      outputs[i].copy_shared_buffer(inputs[i]);
+    } else {
+      outputs[i].set_data(allocator::malloc_or_wait(outputs[i].nbytes()));
+    }
+  }
+
+  if (inputs.size() == 1) {
+    switch (reduce_type_) {
+      case Sum:
+        distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
+        break;
+      default:
+        throw std::runtime_error("Only all reduce sum is supported for now");
+    }
+  } else {
+    switch (reduce_type_) {
+      case Sum:
+        distributed::detail::all_reduce_sum(group(), inputs, outputs);
+        break;
+      default:
+        throw std::runtime_error("Only all reduce sum is supported for now");
+    }
   }
 }
 
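The rewritten eval_cpu first decides, per input, whether the output can take over the input's buffer or needs fresh storage. It passes `is_donatable(outputs.size())` rather than the default of 1, presumably because each sibling output of the multi-output primitive holds its own reference to the inputs. The shape of that donate-or-allocate decision, as a self-contained sketch with shared_ptr standing in for MLX's buffer bookkeeping:

#include <memory>
#include <vector>

using Buffer = std::shared_ptr<std::vector<float>>;

// Donate when the input buffer is uniquely owned: the reduction can then run
// in place with no copy or allocation. Otherwise allocate fresh storage.
Buffer make_output(const Buffer& input) {
  if (input.use_count() == 1) {
    return input; // donation
  }
  return std::make_shared<std::vector<float>>(input->size());
}

int main() {
  std::vector<Buffer> inputs;
  inputs.push_back(std::make_shared<std::vector<float>>(8, 1.0f));
  inputs.push_back(std::make_shared<std::vector<float>>(8, 2.0f));

  std::vector<Buffer> outputs;
  for (const auto& in : inputs) {
    outputs.push_back(make_output(in)); // donates here: sole owner
  }
  return 0;
}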
@@ -6,6 +6,7 @@
 
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/ops.h"
+#include "python/src/trees.h"
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -60,19 +61,30 @@ void init_distributed(nb::module_& parent_module) {
 
   m.def(
       "all_reduce_sum",
-      &distributed::all_reduce_sum,
-      "x"_a,
+      [](const nb::args& args, std::optional<distributed::Group> group) {
+        auto [xs, structure] =
+            tree_flatten_with_structure((args.size() == 1) ? args[0] : args);
+        auto ys = distributed::all_reduce_sum(xs, group);
+        return tree_unflatten_from_structure(structure, ys);
+      },
+      nb::arg(),
       nb::kw_only(),
       "group"_a = nb::none(),
       nb::sig(
-          "def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
+          "def all_reduce_sum(*args, group: Optional[Group] = None) -> array"),
       R"pbdoc(
         All reduce sum.
 
-        Sum the ``x`` arrays from all processes in the group.
+        Sum the passed :class:`array` or the tree of :class:`array` from all
+        processes in the group.
+
+        .. note::
+
+          In order for the reduction to work, the iteration order of the passed
+          trees of arrays must be the same across machines.
 
         Args:
-          x (array): Input array.
+          *args (arrays or trees of arrays): Input arrays.
           group (Group): The group of processes that will participate in the
             reduction. If set to ``None`` the global group is used. Default:
             ``None``.