	Adds send/recv ops in distributed (#1366)
Author: Angelos Katharopoulos
Committed by: GitHub
Parent: 1d94ac3f90
Commit: cdb59faea6
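
As a rough illustration of the Python API this commit adds (a minimal sketch, assuming two processes launched through the MPI backend; variable names are illustrative):

import mlx.core as mx

world = mx.distributed.init()

if world.rank() == 0:
    # send returns an empty array; evaluating it performs the actual send
    mx.eval(mx.distributed.send(mx.ones((4,)), 1, group=world))
else:
    # receive an array of known shape and dtype from rank 0
    y = mx.distributed.recv((4,), mx.float32, 0, group=world)
    mx.eval(y)
    # mx.distributed.recv_like(x, 0) would infer shape and dtype from x instead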
				
@@ -17,3 +17,6 @@ made available.
    init
    all_sum
    all_gather
    send
    recv
    recv_like

@@ -10,7 +10,7 @@

namespace mlx::core::distributed {

void signal_and_wait(const array& in, const array& out, const Stream s) {
void signal_and_wait(const array& in, const array& out, const Stream& s) {
  auto& d = metal::device(s.device);
  d.end_encoding(s.index);
  auto command_buffer = d.get_command_buffer(s.index);
@@ -81,4 +81,62 @@ void AllGather::eval_gpu(
  signal_and_wait(in, out, stream());
}

void Send::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto& in = inputs[0];
  auto& out = outputs[0];

  // Schedule an async send on the comm stream
  auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
    if (in.event().valid()) {
      in.event().wait();
    }
    distributed::detail::send(group, in, dst);
    out.event().signal();
  };
  scheduler::enqueue(detail::communication_stream(), std::move(task));

  // Encode a signal event for the input but not a wait since we don't need to
  // wait on the output.
  auto& s = stream();
  auto& d = metal::device(s.device);
  d.end_encoding(s.index);
  auto command_buffer = d.get_command_buffer(s.index);
  if (in.event().valid()) {
    command_buffer->encodeSignalEvent(
        static_cast<MTL::Event*>(in.event().raw_event().get()),
        in.event().value());
  }
}

void Recv::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 0);
  assert(outputs.size() == 1);

  auto& out = outputs[0];

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  // Schedule an async recv on the comm stream
  auto task = [out = out, group = group(), src = src_]() mutable {
    distributed::detail::recv(group, out, src);
    out.event().signal();
  };
  scheduler::enqueue(detail::communication_stream(), std::move(task));

  // Encode a wait event as there is no input for the recv to encode a signal.
  auto& s = stream();
  auto& d = metal::device(s.device);
  auto command_buffer = d.get_command_buffer(s.index);
  command_buffer->encodeWait(
      static_cast<MTL::Event*>(out.event().raw_event().get()),
      out.event().value());
}

} // namespace mlx::core::distributed

@@ -126,6 +126,8 @@ NO_GPU_MULTI(CustomKernel)
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed

} // namespace mlx::core

@@ -50,17 +50,4 @@ struct Group {
 */
Group init(bool strict = false);

namespace detail {

/* Return the communication stream. */
Stream communication_stream();

/* Perform an all reduce sum operation */
void all_sum(Group group, const array& input, array& output);

/* Perform an all reduce sum operation */
void all_gather(Group group, const array& input, array& output);

} // namespace detail

} // namespace mlx::core::distributed

@@ -12,7 +12,13 @@ Stream communication_stream();
/* Perform an all reduce sum operation */
void all_sum(Group group, const array& input, array& output);

/* Perform an all reduce sum operation */
/* Perform an all gather operation */
void all_gather(Group group, const array& input, array& output);

/** Send an array to the dst rank */
void send(Group group, const array& input, int dst);

/** Recv an array from the src rank */
void recv(Group group, array& out, int src);

} // namespace mlx::core::distributed::detail

@@ -48,6 +48,8 @@ struct MPIWrapper {
    LOAD_SYMBOL(MPI_Comm_free, comm_free);
    LOAD_SYMBOL(MPI_Allreduce, all_reduce);
    LOAD_SYMBOL(MPI_Allgather, all_gather);
    LOAD_SYMBOL(MPI_Send, send);
    LOAD_SYMBOL(MPI_Recv, recv);

    // Objects
    LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@@ -142,6 +144,8 @@ struct MPIWrapper {
      MPI_Comm);
  int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
  int (*comm_free)(MPI_Comm*);
  int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
  int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);

  // Objects
  MPI_Comm comm_world_;
@@ -285,6 +289,29 @@ void all_gather(Group group, const array& input_, array& output) {
      to_comm(group));
}

void send(Group group, const array& input_, int dst) {
  array input = ensure_row_contiguous(input_);
  mpi().send(
      input.data<void>(),
      input.size(),
      mpi().datatype(input),
      dst,
      0,
      to_comm(group));
}

void recv(Group group, array& out, int src) {
  MPI_Status status;
  mpi().recv(
      out.data<void>(),
      out.size(),
      mpi().datatype(out),
      src,
      MPI_ANY_TAG,
      to_comm(group),
      &status);
}

} // namespace detail

} // namespace mlx::core::distributed

@@ -34,6 +34,8 @@ Stream communication_stream() {

void all_sum(Group group, const array& input, array& output) {}
void all_gather(Group group, const array& input, array& output) {}
void send(Group group, const array& input, int dst) {}
void recv(Group group, array& out, int src) {}

} // namespace detail

@@ -1,5 +1,7 @@
// Copyright © 2024 Apple Inc.

#include <sstream>

#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"

@@ -57,4 +59,59 @@ array all_gather(
      {x});
}

array send(
    const array& x,
    int dst,
    std::optional<Group> group_ /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  auto group = to_group(group_);

  if (group.size() == 1) {
    throw std::invalid_argument("Cannot send to a singleton group");
  }

  if (dst < 0 || dst >= group.size()) {
    std::ostringstream msg;
    msg << "Invalid destination=" << dst << " for a group of size "
        << group.size();
    throw std::invalid_argument(msg.str());
  }

  return array(
      {0}, int32, std::make_shared<Send>(to_stream(s), group, dst), {x});
}

array recv(
    std::vector<int> shape,
    Dtype dtype,
    int src,
    std::optional<Group> group_ /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  auto group = to_group(group_);

  if (group.size() == 1) {
    throw std::invalid_argument("Cannot recv from a singleton group");
  }

  if (src < 0 || src >= group.size()) {
    std::ostringstream msg;
    msg << "Invalid source=" << src << " for a group of size " << group.size();
    throw std::invalid_argument(msg.str());
  }

  return array(
      std::move(shape),
      std::move(dtype),
      std::make_shared<Recv>(to_stream(s), group, src),
      std::vector<array>{});
}

array recv_like(
    const array& x,
    int src,
    std::optional<Group> group_ /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  return recv(x.shape(), x.dtype(), src, group_, s);
}

} // namespace mlx::core::distributed

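The ops validate their arguments before any communication is scheduled: a singleton group and an out-of-range rank are both rejected with std::invalid_argument. A hedged sketch of how this is expected to look from Python (the exact exception type depends on the bindings' exception translation, typically ValueError):

import mlx.core as mx

world = mx.distributed.init()
if world.size() == 1:
    try:
        # Sending within a single-process group is rejected up front
        mx.distributed.send(mx.ones(4), 0, group=world)
    except Exception as e:
        print(e)  # "Cannot send to a singleton group"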
@@ -13,9 +13,29 @@ array all_sum(
    const array& x,
    std::optional<Group> group = std::nullopt,
    StreamOrDevice s = {});

array all_gather(
    const array& x,
    std::optional<Group> group = std::nullopt,
    StreamOrDevice S = {});

array send(
    const array& x,
    int dst,
    std::optional<Group> group = std::nullopt,
    StreamOrDevice s = {});

array recv(
    std::vector<int> shape,
    Dtype dtype,
    int src,
    std::optional<Group> group = std::nullopt,
    StreamOrDevice s = {});

array recv_like(
    const array& x,
    int src,
    std::optional<Group> group = std::nullopt,
    StreamOrDevice s = {});

} // namespace mlx::core::distributed

@@ -35,7 +35,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
    const std::vector<int>& axes) {
  switch (reduce_type_) {
    case Sum:
      return {{all_sum(inputs[0], group())}, axes};
      return {{all_sum(inputs[0], group(), stream())}, axes};
    default:
      throw std::runtime_error("Only all reduce sum is supported for now");
  }
@@ -47,7 +47,7 @@ std::vector<array> AllReduce::jvp(
    const std::vector<int>& argnums) {
  switch (reduce_type_) {
    case Sum:
      return {all_sum(tangents[0], group())};
      return {all_sum(tangents[0], group(), stream())};
    default:
      throw std::runtime_error("Only all reduce sum is supported for now");
  }
@@ -75,14 +75,14 @@ void AllGather::eval_cpu(
std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  return {{all_gather(inputs[0], group())}, axes};
  return {{all_gather(inputs[0], group(), stream())}, axes};
}

std::vector<array> AllGather::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  return {all_gather(tangents[0], group())};
  return {all_gather(tangents[0], group(), stream())};
}

std::vector<array> AllGather::vjp(
@@ -98,4 +98,29 @@ std::vector<array> AllGather::vjp(
  return {slice(cotangents[0], starts, stops)};
}

void Send::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  distributed::detail::send(group(), inputs[0], dst_);
}

std::pair<std::vector<array>, std::vector<int>> Send::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  return {{send(inputs[0], dst_, group(), stream())}, axes};
}

void Recv::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 0);
  assert(outputs.size() == 1);

  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
  distributed::detail::recv(group(), outputs[0], src_);
}

} // namespace mlx::core::distributed

@@ -97,4 +97,39 @@ class AllGather : public DistPrimitive {
  DEFINE_PRINT(AllGather);
};

class Send : public DistPrimitive {
 public:
  Send(Stream stream, Group group, int dst)
      : DistPrimitive(stream, group), dst_(dst) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_PRINT(Send);

 private:
  int dst_;
};

class Recv : public DistPrimitive {
 public:
  Recv(Stream stream, Group group, int src)
      : DistPrimitive(stream, group), src_(src) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_PRINT(Recv);

 private:
  int src_;
};

} // namespace mlx::core::distributed

@@ -4,6 +4,7 @@
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>

#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
@@ -121,4 +122,90 @@ void init_distributed(nb::module_& parent_module) {
        Returns:
          array: The concatenation of all ``x`` arrays.
      )pbdoc");

  m.def(
      "send",
      &distributed::send,
      "x"_a,
      "dst"_a,
      nb::kw_only(),
      "group"_a = nb::none(),
      "stream"_a = nb::none(),
      nb::sig(
          "def send(x: array, dst: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Send an array from the current process to the process that has rank
        ``dst`` in the group.

        Args:
          x (array): Input array.
          dst (int): Rank of the destination process in the group.
          group (Group): The group of processes that will participate in the
            send. If set to ``None`` the global group is used. Default:
            ``None``.
          stream (Stream, optional): Stream or device. Defaults to ``None``
            in which case the default stream of the default device is used.

        Returns:
          array: An empty array which, when evaluated, performs the send.
      )pbdoc");

  m.def(
      "recv",
      &distributed::recv,
      "shape"_a,
      "dtype"_a,
      "src"_a,
      nb::kw_only(),
      "group"_a = nb::none(),
      "stream"_a = nb::none(),
      nb::sig(
          "def recv(shape: Sequence[int], dtype: Dtype, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Recv an array with shape ``shape`` and dtype ``dtype`` from process
        with rank ``src``.

        Args:
          shape (Tuple[int]): The shape of the array we are receiving.
          dtype (Dtype): The data type of the array we are receiving.
          src (int): Rank of the source process in the group.
          group (Group): The group of processes that will participate in the
            recv. If set to ``None`` the global group is used. Default:
            ``None``.
          stream (Stream, optional): Stream or device. Defaults to ``None``
            in which case the default stream of the default device is used.

        Returns:
          array: The array that was received from ``src``.
      )pbdoc");

  m.def(
      "recv_like",
      &distributed::recv_like,
      "x"_a,
      "src"_a,
      nb::kw_only(),
      "group"_a = nb::none(),
      "stream"_a = nb::none(),
      nb::sig(
          "def recv_like(x: array, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Recv an array with shape and type like ``x`` from process with rank
        ``src``.

        It is equivalent to calling ``mx.distributed.recv(x.shape, x.dtype, src)``.

        Args:
          x (array): An array defining the shape and dtype of the array we are
            receiving.
          src (int): Rank of the source process in the group.
          group (Group): The group of processes that will participate in the
            recv. If set to ``None`` the global group is used. Default:
            ``None``.
          stream (Stream, optional): Stream or device. Defaults to ``None``
            in which case the default stream of the default device is used.

        Returns:
          array: The array that was received from ``src``.
      )pbdoc");
}

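Per the signatures registered above, ``recv`` needs an explicit shape and dtype while ``recv_like`` borrows them from an existing array. A minimal sketch of the two receive forms, assuming a peer with rank ``src`` has posted matching sends:

import mlx.core as mx

group = mx.distributed.init()
src = 0
a = mx.distributed.recv((10,), mx.float32, src, group=group)
b = mx.distributed.recv_like(mx.zeros((10,)), src, group=group)
mx.eval(a, b)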
@@ -93,6 +93,23 @@ class TestDistributed(mlx_tests.MLXTestCase):

        self.assertTrue(mx.all(z == z_target))

    def test_send_recv(self):
        world = mx.distributed.init()
        pairs = world.split(world.rank() // 2)
        neighbor = (pairs.rank() + 1) % 2
        send = pairs.rank() == 0

        x = mx.ones(10)
        for i in range(10):
            if send:
                mx.eval(mx.distributed.send(2 * x, neighbor, group=pairs))
            else:
                x = mx.distributed.recv_like(x, neighbor, group=pairs)
                mx.eval(x)
            send = not send

        self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))


if __name__ == "__main__":
    unittest.main()

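In the added test the processes are paired off and alternate roles each round: the sender transmits 2 * x and the receiver replaces its x with the received value, so the value doubles once per exchange. Rank 1 of each pair last receives on round 9 of 10 (2^9 = 512) and rank 0 on round 10 (2^10 = 1024), which is exactly what the final assertion checks. MPI-backed tests like this are typically launched with mpirun and at least two processes; the test file's path is not shown in this view, so the command below is only illustrative:

    mpirun -np 2 python python/tests/mpi_test_distributed.py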