	Adds send/recv ops in distributed (#1366)
This commit is contained in:
commit cdb59faea6 (committed via GitHub)
parent 1d94ac3f90
@@ -17,3 +17,6 @@ made available.
     init
     all_sum
     all_gather
+    send
+    recv
+    recv_like

@@ -10,7 +10,7 @@
 
 namespace mlx::core::distributed {
 
-void signal_and_wait(const array& in, const array& out, const Stream s) {
+void signal_and_wait(const array& in, const array& out, const Stream& s) {
   auto& d = metal::device(s.device);
   d.end_encoding(s.index);
   auto command_buffer = d.get_command_buffer(s.index);

@@ -81,4 +81,62 @@ void AllGather::eval_gpu(
   signal_and_wait(in, out, stream());
 }
 
+void Send::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  auto& in = inputs[0];
+  auto& out = outputs[0];
+
+  // Schedule an async send on the comm stream
+  auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
+    if (in.event().valid()) {
+      in.event().wait();
+    }
+    distributed::detail::send(group, in, dst);
+    out.event().signal();
+  };
+  scheduler::enqueue(detail::communication_stream(), std::move(task));
+
+  // Encode a signal event for the input but not a wait since we don't need to
+  // wait on the output.
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+  d.end_encoding(s.index);
+  auto command_buffer = d.get_command_buffer(s.index);
+  if (in.event().valid()) {
+    command_buffer->encodeSignalEvent(
+        static_cast<MTL::Event*>(in.event().raw_event().get()),
+        in.event().value());
+  }
+}
+
+void Recv::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 0);
+  assert(outputs.size() == 1);
+
+  auto& out = outputs[0];
+
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  // Schedule an async recv on the comm stream
+  auto task = [out = out, group = group(), src = src_]() mutable {
+    distributed::detail::recv(group, out, src);
+    out.event().signal();
+  };
+  scheduler::enqueue(detail::communication_stream(), std::move(task));
+
+  // Encode a wait event as there is no input for the recv to encode a signal.
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+  auto command_buffer = d.get_command_buffer(s.index);
+  command_buffer->encodeWait(
+      static_cast<MTL::Event*>(out.event().raw_event().get()),
+      out.event().value());
+}
+
 } // namespace mlx::core::distributed

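The GPU paths above do no Metal compute; they bridge the lazy graph to a
blocking MPI call on the dedicated communication stream: the enqueued task
waits on the input's event, performs the transfer, then signals the output's
event, while the Metal command buffer only encodes the matching signal (Send)
or wait (Recv). A runnable toy analogue of that scheduling pattern, using a
worker thread as the "communication stream" (illustrative only, not MLX API):

    import queue
    import threading

    comm_stream = queue.Queue()  # analogue of detail::communication_stream()

    def comm_worker():
        while (task := comm_stream.get()) is not None:
            task()

    def send_eval(in_ready, out_done, payload):
        def task():
            in_ready.wait()         # analogue of in.event().wait()
            print("sent", payload)  # analogue of distributed::detail::send()
            out_done.set()          # analogue of out.event().signal()
        comm_stream.put(task)       # analogue of scheduler::enqueue(...)

    worker = threading.Thread(target=comm_worker)
    worker.start()
    in_ready, out_done = threading.Event(), threading.Event()
    send_eval(in_ready, out_done, [1, 2, 3])
    in_ready.set()   # the producer (the GPU command buffer) signals input ready
    out_done.wait()  # a consumer waits for the send to complete
    comm_stream.put(None)
    worker.join()
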
@@ -126,6 +126,8 @@ NO_GPU_MULTI(CustomKernel)
 namespace distributed {
 NO_GPU_MULTI(AllReduce)
 NO_GPU_MULTI(AllGather)
+NO_GPU_MULTI(Send)
+NO_GPU_MULTI(Recv)
 } // namespace distributed
 
 } // namespace mlx::core

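These NO_GPU_MULTI registrations make the new primitives throw when evaluated
on the GPU in builds without Metal, mirroring the existing AllReduce and
AllGather entries.
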
@@ -50,17 +50,4 @@ struct Group {
  */
 Group init(bool strict = false);
 
-namespace detail {
-
-/* Return the communication stream. */
-Stream communication_stream();
-
-/* Perform an all reduce sum operation */
-void all_sum(Group group, const array& input, array& output);
-
-/* Perform an all reduce sum operation */
-void all_gather(Group group, const array& input, array& output);
-
-} // namespace detail
-
 } // namespace mlx::core::distributed

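Note that the `detail` declarations removed here are not deleted from the
project: the next hunk extends a separate internal header (see its
`communication_stream()` context line and its
`mlx::core::distributed::detail` namespace), so this change just moves the
internal API out of the public header.
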
@@ -12,7 +12,13 @@ Stream communication_stream();
 /* Perform an all reduce sum operation */
 void all_sum(Group group, const array& input, array& output);
 
-/* Perform an all reduce sum operation */
+/* Perform an all gather operation */
 void all_gather(Group group, const array& input, array& output);
 
+/** Send an array to the dst rank */
+void send(Group group, const array& input, int dst);
+
+/** Recv an array from the src rank */
+void recv(Group group, array& out, int src);
+
 } // namespace mlx::core::distributed::detail

@@ -48,6 +48,8 @@ struct MPIWrapper {
     LOAD_SYMBOL(MPI_Comm_free, comm_free);
     LOAD_SYMBOL(MPI_Allreduce, all_reduce);
     LOAD_SYMBOL(MPI_Allgather, all_gather);
+    LOAD_SYMBOL(MPI_Send, send);
+    LOAD_SYMBOL(MPI_Recv, recv);
 
     // Objects
     LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);

@@ -142,6 +144,8 @@ struct MPIWrapper {
       MPI_Comm);
   int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
   int (*comm_free)(MPI_Comm*);
+  int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
+  int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);
 
   // Objects
   MPI_Comm comm_world_;

@@ -285,6 +289,29 @@ void all_gather(Group group, const array& input_, array& output) {
       to_comm(group));
 }
 
+void send(Group group, const array& input_, int dst) {
+  array input = ensure_row_contiguous(input_);
+  mpi().send(
+      input.data<void>(),
+      input.size(),
+      mpi().datatype(input),
+      dst,
+      0,
+      to_comm(group));
+}
+
+void recv(Group group, array& out, int src) {
+  MPI_Status status;
+  mpi().recv(
+      out.data<void>(),
+      out.size(),
+      mpi().datatype(out),
+      src,
+      MPI_ANY_TAG,
+      to_comm(group),
+      &status);
+}
+
 } // namespace detail
 
 } // namespace mlx::core::distributed

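Two details worth noting in the MPI backend: the sender always uses tag 0
while the receiver matches `MPI_ANY_TAG`, so messages between a given pair of
ranks are paired purely by program order; and `send` copies non-contiguous
inputs via `ensure_row_contiguous` before handing the buffer to MPI, whereas
`recv` writes into `out` directly, so the caller must allocate the output
beforehand.
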
@@ -34,6 +34,8 @@ Stream communication_stream() {
 
 void all_sum(Group group, const array& input, array& output) {}
 void all_gather(Group group, const array& input, array& output) {}
+void send(Group group, const array& input, int dst) {}
+void recv(Group group, array& out, int src) {}
 
 } // namespace detail
 

@@ -1,5 +1,7 @@
 // Copyright © 2024 Apple Inc.
 
+#include <sstream>
+
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
 

@@ -57,4 +59,59 @@ array all_gather(
       {x});
 }
 
+array send(
+    const array& x,
+    int dst,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  auto group = to_group(group_);
+
+  if (group.size() == 1) {
+    throw std::invalid_argument("Cannot send to a singleton group");
+  }
+
+  if (dst < 0 || dst >= group.size()) {
+    std::ostringstream msg;
+    msg << "Invalid destination=" << dst << " for a group of size "
+        << group.size();
+    throw std::invalid_argument(msg.str());
+  }
+
+  return array(
+      {0}, int32, std::make_shared<Send>(to_stream(s), group, dst), {x});
+}
+
+array recv(
+    std::vector<int> shape,
+    Dtype dtype,
+    int src,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  auto group = to_group(group_);
+
+  if (group.size() == 1) {
+    throw std::invalid_argument("Cannot recv from a singleton group");
+  }
+
+  if (src < 0 || src >= group.size()) {
+    std::ostringstream msg;
+    msg << "Invalid source=" << src << " for a group of size " << group.size();
+    throw std::invalid_argument(msg.str());
+  }
+
+  return array(
+      std::move(shape),
+      std::move(dtype),
+      std::make_shared<Recv>(to_stream(s), group, src),
+      std::vector<array>{});
+}
+
+array recv_like(
+    const array& x,
+    int src,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  return recv(x.shape(), x.dtype(), src, group_, s);
+}
+
 } // namespace mlx::core::distributed

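Both ops validate their arguments eagerly: a singleton group is rejected, and
the peer rank is range-checked against the group size. Since `send` returns a
shape-`{0}` int32 array that depends on `x`, evaluating that array is what
forces the transfer. A small sketch of the validation from Python (assumes
nanobind's usual mapping of `std::invalid_argument` to `ValueError`; runs on
a single process):

    import mlx.core as mx

    world = mx.distributed.init()
    if world.size() == 1:
        try:
            mx.distributed.send(mx.ones(4), 0)
        except ValueError as e:
            print(e)  # Cannot send to a singleton group
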
@@ -13,9 +13,29 @@ array all_sum(
     const array& x,
     std::optional<Group> group = std::nullopt,
     StreamOrDevice s = {});
 
 array all_gather(
     const array& x,
     std::optional<Group> group = std::nullopt,
     StreamOrDevice S = {});
 
+array send(
+    const array& x,
+    int dst,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
+
+array recv(
+    std::vector<int> shape,
+    Dtype dtype,
+    int src,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
+
+array recv_like(
+    const array& x,
+    int src,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
+
 } // namespace mlx::core::distributed

@@ -35,7 +35,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
     const std::vector<int>& axes) {
   switch (reduce_type_) {
     case Sum:
-      return {{all_sum(inputs[0], group())}, axes};
+      return {{all_sum(inputs[0], group(), stream())}, axes};
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");
   }

@@ -47,7 +47,7 @@ std::vector<array> AllReduce::jvp(
     const std::vector<int>& argnums) {
   switch (reduce_type_) {
     case Sum:
-      return {all_sum(tangents[0], group())};
+      return {all_sum(tangents[0], group(), stream())};
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");
   }

@@ -75,14 +75,14 @@ void AllGather::eval_cpu(
 std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  return {{all_gather(inputs[0], group())}, axes};
+  return {{all_gather(inputs[0], group(), stream())}, axes};
 }
 
 std::vector<array> AllGather::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
-  return {all_gather(tangents[0], group())};
+  return {all_gather(tangents[0], group(), stream())};
 }
 
 std::vector<array> AllGather::vjp(

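The `vmap`/`jvp` changes in the last three hunks are a drive-by fix: the
transformed collectives previously ran on the default stream, and now
explicitly reuse the primitive's `stream()`, matching how the new
`Send::vmap` below is written.
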
@@ -98,4 +98,29 @@ std::vector<array> AllGather::vjp(
   return {slice(cotangents[0], starts, stops)};
 }
 
+void Send::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  distributed::detail::send(group(), inputs[0], dst_);
+}
+
+std::pair<std::vector<array>, std::vector<int>> Send::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  return {{send(inputs[0], dst_, group(), stream())}, axes};
+}
+
+void Recv::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 0);
+  assert(outputs.size() == 1);
+
+  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
+  distributed::detail::recv(group(), outputs[0], src_);
+}
+
 } // namespace mlx::core::distributed

@@ -97,4 +97,39 @@ class AllGather : public DistPrimitive {
   DEFINE_PRINT(AllGather);
 };
 
+class Send : public DistPrimitive {
+ public:
+  Send(Stream stream, Group group, int dst)
+      : DistPrimitive(stream, group), dst_(dst) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  std::pair<std::vector<array>, std::vector<int>> vmap(
+      const std::vector<array>& inputs,
+      const std::vector<int>& axes) override;
+
+  DEFINE_PRINT(Send);
+
+ private:
+  int dst_;
+};
+
+class Recv : public DistPrimitive {
+ public:
+  Recv(Stream stream, Group group, int src)
+      : DistPrimitive(stream, group), src_(src) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_PRINT(Recv);
+
+ private:
+  int src_;
+};
+
 } // namespace mlx::core::distributed

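Note the asymmetry between the two classes: `Send` overrides `vmap` (its
input provides a batch axis to map over) while `Recv` does not, and `Recv`
carries no input dependencies at all, which is why its GPU path earlier
encodes a wait rather than a signal.
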
@@ -4,6 +4,7 @@
 #include <nanobind/stl/optional.h>
 #include <nanobind/stl/shared_ptr.h>
 #include <nanobind/stl/variant.h>
+#include <nanobind/stl/vector.h>
 
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/ops.h"

@@ -121,4 +122,90 @@ void init_distributed(nb::module_& parent_module) {
         Returns:
           array: The concatenation of all ``x`` arrays.
       )pbdoc");
+
+  m.def(
+      "send",
+      &distributed::send,
+      "x"_a,
+      "dst"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def send(x: array, dst: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Send an array from the current process to the process that has rank
+        ``dst`` in the group.
+
+        Args:
+          x (array): Input array.
+          dst (int): Rank of the destination process in the group.
+          group (Group): The group of processes that will participate in the
+            send. If set to ``None`` the global group is used. Default:
+            ``None``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``
+            in which case the default stream of the default device is used.
+
+        Returns:
+          array: An empty array which, when evaluated, performs the send.
+      )pbdoc");
+
+  m.def(
+      "recv",
+      &distributed::recv,
+      "shape"_a,
+      "dtype"_a,
+      "src"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def recv(shape: Sequence[int], dtype: Dtype, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Recv an array with shape ``shape`` and dtype ``dtype`` from process
+        with rank ``src``.
+
+        Args:
+          shape (Tuple[int]): The shape of the array we are receiving.
+          dtype (Dtype): The data type of the array we are receiving.
+          src (int): Rank of the source process in the group.
+          group (Group): The group of processes that will participate in the
+            recv. If set to ``None`` the global group is used. Default:
+            ``None``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``
+            in which case the default stream of the default device is used.
+
+        Returns:
+          array: The array that was received from ``src``.
+      )pbdoc");
+
+  m.def(
+      "recv_like",
+      &distributed::recv_like,
+      "x"_a,
+      "src"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def recv_like(x: array, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Recv an array with shape and type like ``x`` from process with rank
+        ``src``.
+
+        It is equivalent to calling ``mx.distributed.recv(x.shape, x.dtype, src)``.
+
+        Args:
+          x (array): An array defining the shape and dtype of the array we are
+            receiving.
+          src (int): Rank of the source process in the group.
+          group (Group): The group of processes that will participate in the
+            recv. If set to ``None`` the global group is used. Default:
+            ``None``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``
+            in which case the default stream of the default device is used.
+
+        Returns:
+          array: The array that was received from ``src``.
+      )pbdoc");
 }

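Unlike the collectives, `recv` cannot infer its output's shape and dtype from
any input, so they are passed explicitly; `recv_like` simply borrows them
from an existing array. A short sketch (assumes a peer with rank 0 sending
two 4x4 float32 arrays):

    import mlx.core as mx

    y = mx.distributed.recv((4, 4), mx.float32, src=0)  # explicit shape/dtype
    z = mx.distributed.recv_like(y, src=0)              # shape/dtype from y
    mx.eval(y, z)
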
@@ -93,6 +93,23 @@ class TestDistributed(mlx_tests.MLXTestCase):
 
         self.assertTrue(mx.all(z == z_target))
 
+    def test_send_recv(self):
+        world = mx.distributed.init()
+        pairs = world.split(world.rank() // 2)
+        neighbor = (pairs.rank() + 1) % 2
+        send = pairs.rank() == 0
+
+        x = mx.ones(10)
+        for i in range(10):
+            if send:
+                mx.eval(mx.distributed.send(2 * x, neighbor, group=pairs))
+            else:
+                x = mx.distributed.recv_like(x, neighbor, group=pairs)
+                mx.eval(x)
+            send = not send
+
+        self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
+
+
 if __name__ == "__main__":
     unittest.main()

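In the test, the two ranks of each pair alternate roles and the active sender
always transmits double its current value, so the payload doubles once per
round: rank 1 receives on the even iterations (2, 8, ..., 2^9 = 512) and
rank 0 on the odd ones (4, 16, ..., 2^10 = 1024), which is exactly what the
final assertion checks.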