	Add async all reduce and donation
@@ -252,8 +252,9 @@ class array {
   }
 
   /** True indicates the arrays buffer is safe to reuse */
-  bool is_donatable() const {
-    return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
+  bool is_donatable(int known_instances = 1) const {
+    return array_desc_.use_count() == known_instances &&
+        (array_desc_->data.use_count() == 1);
   }
 
   /** The array's siblings. */

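The `known_instances` parameter generalizes donation to multi-output primitives: with N sibling outputs, each output keeps its own reference to every input, so during evaluation an input whose descriptor has exactly N owners (and whose buffer has a single owner) is safe to reuse. The CPU evaluation further down calls `is_donatable(outputs.size())` for exactly this reason. A standalone `std::shared_ptr` sketch of the counting argument (not MLX code; `Desc` is a stand-in for the array descriptor):

#include <cassert>
#include <memory>
#include <vector>

// Stand-in for an array descriptor that owns a data buffer.
struct Desc {
  std::shared_ptr<int> data = std::make_shared<int>(0);
};

// Donatable when the descriptor has exactly `known_instances` owners and the
// buffer itself has no other owner.
bool is_donatable(const std::shared_ptr<Desc>& d, long known_instances = 1) {
  return d.use_count() == known_instances && d->data.use_count() == 1;
}

int main() {
  auto desc = std::make_shared<Desc>();
  assert(is_donatable(desc)); // single owner: safe to reuse

  // Three sibling outputs each keep a reference, plus `desc` itself: 4 owners.
  std::vector<std::shared_ptr<Desc>> siblings{desc, desc, desc};
  assert(!is_donatable(siblings[0]));   // the default of 1 no longer matches
  assert(is_donatable(siblings[0], 4)); // but 4 known instances does
  return 0;
}
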
@@ -53,6 +53,10 @@ Stream communication_stream();
 
 /* Perform an all reduce sum operation */
 void all_reduce_sum(Group group, const array& input, array& output);
+void all_reduce_sum(
+    Group group,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs);
 
 /* Perform an all reduce sum operation */
 void all_gather(Group group, const array& input, array& output);

@@ -46,7 +46,9 @@ struct MPIWrapper {
     LOAD_SYMBOL(MPI_Comm_split, comm_split);
     LOAD_SYMBOL(MPI_Comm_free, comm_free);
     LOAD_SYMBOL(MPI_Allreduce, all_reduce);
+    LOAD_SYMBOL(MPI_Iallreduce, async_all_reduce);
     LOAD_SYMBOL(MPI_Allgather, all_gather);
+    LOAD_SYMBOL(MPI_Waitall, waitall);
 
     // Objects
     LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@@ -130,6 +132,14 @@ struct MPIWrapper {
   int (*finalize)();
   int (*rank)(MPI_Comm, int*);
   int (*size)(MPI_Comm, int*);
+  int (*async_all_reduce)(
+      const void*,
+      void*,
+      int,
+      MPI_Datatype,
+      MPI_Op,
+      MPI_Comm,
+      MPI_Request*);
   int (*all_reduce)(const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm);
   int (*all_gather)(
       const void*,
@@ -139,6 +149,7 @@ struct MPIWrapper {
       int,
       MPI_Datatype,
       MPI_Comm);
+  int (*waitall)(int, MPI_Request*, MPI_Status*);
   int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
   int (*comm_free)(MPI_Comm*);
 
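LOAD_SYMBOL is presumably a dlsym-style helper that resolves each MPI entry point from the already-loaded library and stores it in the matching function pointer, which keeps MPI an optional runtime dependency; its definition is not part of this diff. A hedged sketch of that pattern for the newly loaded symbol (the handle name is hypothetical):

#include <dlfcn.h>
#include <mpi.h>

// Function pointer with the MPI_Iallreduce signature, filled in at runtime.
int (*async_all_reduce)(
    const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm, MPI_Request*) =
    nullptr;

// Resolve MPI_Iallreduce from an already dlopen-ed libmpi handle.
bool load_iallreduce(void* libmpi_handle) {
  async_all_reduce = reinterpret_cast<decltype(async_all_reduce)>(
      dlsym(libmpi_handle, "MPI_Iallreduce"));
  return async_all_reduce != nullptr;
}
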
@@ -258,7 +269,8 @@ Stream communication_stream() {
 void all_reduce_sum(Group group, const array& input_, array& output) {
   array input = ensure_row_contiguous(input_);
   mpi().all_reduce(
-      input.data<void>(),
+      (input.data<void>() != output.data<void>()) ? input.data<void>()
+                                                  : MPI_IN_PLACE,
       output.data<void>(),
       input.size(),
       mpi().datatype(input),
@@ -266,6 +278,27 @@ void all_reduce_sum(Group group, const array& input_, array& output) {
       to_comm(group));
 }
 
+void all_reduce_sum(
+    Group group,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  std::vector<MPI_Request> requests(inputs.size());
+  std::vector<MPI_Status> statuses(inputs.size());
+  for (int i = 0; i < inputs.size(); i++) {
+    array input = ensure_row_contiguous(inputs[i]);
+    mpi().async_all_reduce(
+        (input.data<void>() != outputs[i].data<void>()) ? input.data<void>()
+                                                        : MPI_IN_PLACE,
+        outputs[i].data<void>(),
+        input.size(),
+        mpi().datatype(input),
+        mpi().op_sum(),
+        to_comm(group),
+        &requests[i]);
+  }
+  mpi().waitall(requests.size(), &requests[0], &statuses[0]);
+}
+
 void all_gather(Group group, const array& input_, array& output) {
   array input = ensure_row_contiguous(input_);
   mpi().all_gather(

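The batched overload starts one nonblocking reduction per array and then waits on all of the requests at once, so the individual reductions can overlap on the network; when an output was donated, input and output point at the same buffer and MPI_IN_PLACE is passed instead of a separate send buffer. A standalone MPI sketch of the same pattern (not MLX code; build with an MPI compiler wrapper and run under mpirun):

#include <mpi.h>
#include <vector>

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);

  // Two independent buffers, reduced concurrently across all ranks.
  std::vector<std::vector<float>> buffers = {{1, 2, 3}, {4, 5}};
  std::vector<MPI_Request> requests(buffers.size());

  for (size_t i = 0; i < buffers.size(); i++) {
    // MPI_IN_PLACE: the receive buffer doubles as the send buffer, which is
    // what the donated-input path above amounts to.
    MPI_Iallreduce(
        MPI_IN_PLACE,
        buffers[i].data(),
        static_cast<int>(buffers[i].size()),
        MPI_FLOAT,
        MPI_SUM,
        MPI_COMM_WORLD,
        &requests[i]);
  }

  // Block until every outstanding reduction has completed.
  MPI_Waitall(
      static_cast<int>(requests.size()), requests.data(), MPI_STATUSES_IGNORE);

  MPI_Finalize();
  return 0;
}
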
@@ -32,6 +32,10 @@ Stream communication_stream() {
 }
 
 void all_reduce_sum(Group group, const array& input, array& output) {}
+void all_reduce_sum(
+    Group group,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {}
 void all_gather(Group group, const array& input, array& output) {}
 
 } // namespace detail

@@ -18,17 +18,33 @@ Group to_group(std::optional<Group> group) {
 } // namespace
 
 array all_reduce_sum(const array& x, std::optional<Group> group_) {
+  return all_reduce_sum(std::vector<array>{x}, std::move(group_))[0];
+}
+
+std::vector<array> all_reduce_sum(
+    const std::vector<array>& xs,
+    std::optional<Group> group_) {
   auto group = to_group(group_);
 
   if (group.size() == 1) {
-    return x;
+    return xs;
   }
 
-  return array(
-      x.shape(),
-      x.dtype(),
+  std::vector<std::vector<int>> shapes;
+  std::vector<Dtype> dtypes;
+  shapes.reserve(xs.size());
+  dtypes.reserve(xs.size());
+
+  for (const auto& x : xs) {
+    shapes.push_back(x.shape());
+    dtypes.push_back(x.dtype());
+  }
+
+  return array::make_arrays(
+      std::move(shapes),
+      std::move(dtypes),
       std::make_shared<AllReduce>(group, AllReduce::Sum),
-      {x});
+      xs);
 }
 
 array all_gather(const array& x, std::optional<Group> group_) {

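With array::make_arrays, all of the passed arrays become sibling outputs of a single AllReduce node, so one graph node reduces the whole batch and the single-array overload is now a thin wrapper around the vector version. A hedged usage sketch of the new C++ overload (header paths and the distributed::init() entry point are assumptions about the MLX public API, not part of this diff):

#include "mlx/mlx.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"

using namespace mlx::core;

int main() {
  // Join the global group; without a distributed launch the group has size 1
  // and the reduction simply returns the inputs.
  auto world = distributed::init();

  std::vector<array> grads = {ones({4, 4}), ones({8})};

  // One AllReduce node covers both arrays; the MPI backend can overlap the
  // two reductions instead of running them back to back.
  auto summed = distributed::all_reduce_sum(grads, world);
  eval(summed);
  return 0;
}
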
@@ -9,6 +9,9 @@
 namespace mlx::core::distributed {
 
 array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
+std::vector<array> all_reduce_sum(
+    const std::vector<array>& xs,
+    std::optional<Group> group = std::nullopt);
 array all_gather(const array& x, std::optional<Group> group = std::nullopt);
 
 } // namespace mlx::core::distributed

@@ -13,17 +13,30 @@ namespace mlx::core::distributed {
 void AllReduce::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  assert(inputs.size() == 1);
-  assert(outputs.size() == 1);
+  for (int i = 0; i < inputs.size(); i++) {
+    if (inputs[i].is_donatable(outputs.size())) {
+      outputs[i].copy_shared_buffer(inputs[i]);
+    } else {
+      outputs[i].set_data(allocator::malloc_or_wait(outputs[i].nbytes()));
+    }
+  }
 
-  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
-
-  switch (reduce_type_) {
-    case Sum:
-      distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
-      break;
-    default:
-      throw std::runtime_error("Only all reduce sum is supported for now");
+  if (inputs.size() == 1) {
+    switch (reduce_type_) {
+      case Sum:
+        distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
+        break;
+      default:
+        throw std::runtime_error("Only all reduce sum is supported for now");
+    }
+  } else {
+    switch (reduce_type_) {
+      case Sum:
+        distributed::detail::all_reduce_sum(group(), inputs, outputs);
+        break;
+      default:
+        throw std::runtime_error("Only all reduce sum is supported for now");
+    }
   }
 }
 

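The evaluation pass first decides, per input, whether to reuse the input's buffer (donation, so the reduction runs in place) or to allocate a fresh output, and only then dispatches to the single-array or batched backend call. A standalone mock of that output-preparation step (not MLX code; the types are stand-ins):

#include <cstddef>
#include <memory>
#include <vector>

// Stand-in for an array whose storage is reference counted.
struct MockArray {
  std::shared_ptr<std::vector<float>> storage;
  size_t nbytes = 0;
};

// For each output, alias the input's storage when the graph holds the only
// references to it (one per sibling output); otherwise allocate fresh storage.
void prepare_outputs(
    const std::vector<MockArray>& inputs,
    std::vector<MockArray>& outputs) {
  for (size_t i = 0; i < inputs.size(); i++) {
    bool donatable =
        inputs[i].storage.use_count() == static_cast<long>(outputs.size());
    if (donatable) {
      outputs[i].storage = inputs[i].storage; // donate: reduce in place
    } else {
      outputs[i].storage = std::make_shared<std::vector<float>>(
          outputs[i].nbytes / sizeof(float)); // fresh allocation
    }
  }
}
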
@@ -6,6 +6,7 @@
 
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/ops.h"
+#include "python/src/trees.h"
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -60,19 +61,30 @@ void init_distributed(nb::module_& parent_module) {
 
   m.def(
       "all_reduce_sum",
-      &distributed::all_reduce_sum,
-      "x"_a,
+      [](const nb::args& args, std::optional<distributed::Group> group) {
+        auto [xs, structure] =
+            tree_flatten_with_structure((args.size() == 1) ? args[0] : args);
+        auto ys = distributed::all_reduce_sum(xs, group);
+        return tree_unflatten_from_structure(structure, ys);
+      },
+      nb::arg(),
       nb::kw_only(),
       "group"_a = nb::none(),
       nb::sig(
-          "def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
+          "def all_reduce_sum(*args, group: Optional[Group] = None) -> array"),
       R"pbdoc(
         All reduce sum.
 
-        Sum the ``x`` arrays from all processes in the group.
+        Sum the passed :class:`array` or the tree of :class:`array` from all
+        processes in the group.
 
+        .. note::
+
+          In order for the reduction to work, the iteration order of the passed
+          trees of arrays must be the same across machines.
+
         Args:
-          x (array): Input array.
+          *args (arrays or trees of arrays): Input arrays.
           group (Group): The group of processes that will participate in the
             reduction. If set to ``None`` the global group is used. Default:
             ``None``.

Angelos Katharopoulos