From cdb59faea64383474fdb769b4c0e131408a99060 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 26 Aug 2024 23:01:37 -0700
Subject: [PATCH] Adds send/recv ops in distributed (#1366)

---
 docs/src/python/distributed.rst      |  3 +
 mlx/backend/metal/distributed.cpp    | 60 ++++++++++++++++++-
 mlx/backend/no_metal/primitives.cpp  |  2 +
 mlx/distributed/distributed.h        | 13 -----
 mlx/distributed/distributed_impl.h   |  8 ++-
 mlx/distributed/mpi/mpi.cpp          | 27 +++++++++
 mlx/distributed/no_distributed.cpp   |  2 +
 mlx/distributed/ops.cpp              | 57 ++++++++++++++++++
 mlx/distributed/ops.h                | 20 +++++++
 mlx/distributed/primitives.cpp       | 33 +++++++++--
 mlx/distributed/primitives.h         | 35 +++++++++++
 python/src/distributed.cpp           | 87 ++++++++++++++++++++++++++++
 python/tests/mpi_test_distributed.py | 17 ++++++
 13 files changed, 345 insertions(+), 19 deletions(-)

diff --git a/docs/src/python/distributed.rst b/docs/src/python/distributed.rst
index cf9eae3f1..8b48d727e 100644
--- a/docs/src/python/distributed.rst
+++ b/docs/src/python/distributed.rst
@@ -17,3 +17,6 @@ made available.
    init
    all_sum
    all_gather
+   send
+   recv
+   recv_like
diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp
index 64f7c979c..4cbd56af6 100644
--- a/mlx/backend/metal/distributed.cpp
+++ b/mlx/backend/metal/distributed.cpp
@@ -10,7 +10,7 @@

 namespace mlx::core::distributed {

-void signal_and_wait(const array& in, const array& out, const Stream s) {
+void signal_and_wait(const array& in, const array& out, const Stream& s) {
   auto& d = metal::device(s.device);
   d.end_encoding(s.index);
   auto command_buffer = d.get_command_buffer(s.index);
@@ -81,4 +81,62 @@ void AllGather::eval_gpu(
   signal_and_wait(in, out, stream());
 }

+void Send::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  auto& in = inputs[0];
+  auto& out = outputs[0];
+
+  // Schedule an async send on the comm stream
+  auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
+    if (in.event().valid()) {
+      in.event().wait();
+    }
+    distributed::detail::send(group, in, dst);
+    out.event().signal();
+  };
+  scheduler::enqueue(detail::communication_stream(), std::move(task));
+
+  // Encode a signal event for the input but not a wait since we don't need to
+  // wait on the output.
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+  d.end_encoding(s.index);
+  auto command_buffer = d.get_command_buffer(s.index);
+  if (in.event().valid()) {
+    command_buffer->encodeSignalEvent(
+        static_cast<MTL::Event*>(in.event().raw_event().get()),
+        in.event().value());
+  }
+}
+
+void Recv::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 0);
+  assert(outputs.size() == 1);
+
+  auto& out = outputs[0];
+
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  // Schedule an async recv on the comm stream
+  auto task = [out = out, group = group(), src = src_]() mutable {
+    distributed::detail::recv(group, out, src);
+    out.event().signal();
+  };
+  scheduler::enqueue(detail::communication_stream(), std::move(task));
+
+  // Encode a wait event as there is no input for the recv to encode a signal.
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+  auto command_buffer = d.get_command_buffer(s.index);
+  command_buffer->encodeWait(
+      static_cast<MTL::Event*>(out.event().raw_event().get()),
+      out.event().value());
+}
+
 } // namespace mlx::core::distributed
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index a98313e5d..544a2c6f2 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -126,6 +126,8 @@ NO_GPU_MULTI(CustomKernel)
 namespace distributed {
 NO_GPU_MULTI(AllReduce)
 NO_GPU_MULTI(AllGather)
+NO_GPU_MULTI(Send)
+NO_GPU_MULTI(Recv)
 } // namespace distributed

 } // namespace mlx::core
diff --git a/mlx/distributed/distributed.h b/mlx/distributed/distributed.h
index 44d40bc73..1ed82cb6a 100644
--- a/mlx/distributed/distributed.h
+++ b/mlx/distributed/distributed.h
@@ -50,17 +50,4 @@ struct Group {
  */
 Group init(bool strict = false);

-namespace detail {
-
-/* Return the communication stream. */
-Stream communication_stream();
-
-/* Perform an all reduce sum operation */
-void all_sum(Group group, const array& input, array& output);
-
-/* Perform an all reduce sum operation */
-void all_gather(Group group, const array& input, array& output);
-
-} // namespace detail
-
 } // namespace mlx::core::distributed
diff --git a/mlx/distributed/distributed_impl.h b/mlx/distributed/distributed_impl.h
index 42fd5aac3..7031e33f5 100644
--- a/mlx/distributed/distributed_impl.h
+++ b/mlx/distributed/distributed_impl.h
@@ -12,7 +12,13 @@ Stream communication_stream();
 /* Perform an all reduce sum operation */
 void all_sum(Group group, const array& input, array& output);

-/* Perform an all reduce sum operation */
+/* Perform an all gather operation */
 void all_gather(Group group, const array& input, array& output);

+/** Send an array to the dst rank */
+void send(Group group, const array& input, int dst);
+
+/** Recv an array from the src rank */
+void recv(Group group, array& out, int src);
+
 } // namespace mlx::core::distributed::detail
diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp
index 5c60f9e97..4504ebecb 100644
--- a/mlx/distributed/mpi/mpi.cpp
+++ b/mlx/distributed/mpi/mpi.cpp
@@ -48,6 +48,8 @@ struct MPIWrapper {
     LOAD_SYMBOL(MPI_Comm_free, comm_free);
     LOAD_SYMBOL(MPI_Allreduce, all_reduce);
     LOAD_SYMBOL(MPI_Allgather, all_gather);
+    LOAD_SYMBOL(MPI_Send, send);
+    LOAD_SYMBOL(MPI_Recv, recv);

     // Objects
     LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@@ -142,6 +144,8 @@ struct MPIWrapper {
       MPI_Comm);
   int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
   int (*comm_free)(MPI_Comm*);
+  int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
+  int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);

   // Objects
   MPI_Comm comm_world_;
@@ -285,6 +289,29 @@ void all_gather(Group group, const array& input_, array& output) {
       to_comm(group));
 }

+void send(Group group, const array& input_, int dst) {
+  array input = ensure_row_contiguous(input_);
+  mpi().send(
+      input.data<void>(),
+      input.size(),
+      mpi().datatype(input),
+      dst,
+      0,
+      to_comm(group));
+}
+
+void recv(Group group, array& out, int src) {
+  MPI_Status status;
+  mpi().recv(
+      out.data<void>(),
+      out.size(),
+      mpi().datatype(out),
+      src,
+      MPI_ANY_TAG,
+      to_comm(group),
+      &status);
+}
+
 } // namespace detail

 } // namespace mlx::core::distributed
diff --git a/mlx/distributed/no_distributed.cpp b/mlx/distributed/no_distributed.cpp
index 9c3e19227..009e3a715 100644
--- a/mlx/distributed/no_distributed.cpp
+++ b/mlx/distributed/no_distributed.cpp
@@ -34,6 +34,8 @@ Stream communication_stream() {

 void all_sum(Group group, const array& input, array& output) {}
 void all_gather(Group group, const array& input, array& output) {}
+void send(Group group, const array& input, int dst) {}
+void recv(Group group, array& out, int src) {}

 } // namespace detail

diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp
index 54f8f483b..64b1cbae1 100644
--- a/mlx/distributed/ops.cpp
+++ b/mlx/distributed/ops.cpp
@@ -1,5 +1,7 @@
 // Copyright © 2024 Apple Inc.

+#include <sstream>
+
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"

@@ -57,4 +59,59 @@ array all_gather(
       {x});
 }

+array send(
+    const array& x,
+    int dst,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  auto group = to_group(group_);
+
+  if (group.size() == 1) {
+    throw std::invalid_argument("Cannot send to a singleton group");
+  }
+
+  if (dst < 0 || dst >= group.size()) {
+    std::ostringstream msg;
+    msg << "Invalid destination=" << dst << " for a group of size "
+        << group.size();
+    throw std::invalid_argument(msg.str());
+  }
+
+  return array(
+      {0}, int32, std::make_shared<Send>(to_stream(s), group, dst), {x});
+}
+
+array recv(
+    std::vector<int> shape,
+    Dtype dtype,
+    int src,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  auto group = to_group(group_);
+
+  if (group.size() == 1) {
+    throw std::invalid_argument("Cannot recv from a singleton group");
+  }
+
+  if (src < 0 || src >= group.size()) {
+    std::ostringstream msg;
+    msg << "Invalid source=" << src << " for a group of size " << group.size();
+    throw std::invalid_argument(msg.str());
+  }
+
+  return array(
+      std::move(shape),
+      std::move(dtype),
+      std::make_shared<Recv>(to_stream(s), group, src),
+      std::vector<array>{});
+}
+
+array recv_like(
+    const array& x,
+    int src,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  return recv(x.shape(), x.dtype(), src, group_, s);
+}
+
 } // namespace mlx::core::distributed
diff --git a/mlx/distributed/ops.h b/mlx/distributed/ops.h
index 85e9e99cc..5e9a06515 100644
--- a/mlx/distributed/ops.h
+++ b/mlx/distributed/ops.h
@@ -13,9 +13,29 @@ array all_sum(
     const array& x,
     std::optional<Group> group = std::nullopt,
     StreamOrDevice s = {});
+
 array all_gather(
     const array& x,
     std::optional<Group> group = std::nullopt,
     StreamOrDevice S = {});
+
+array send(
+    const array& x,
+    int dst,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
+
+array recv(
+    std::vector<int> shape,
+    Dtype dtype,
+    int src,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
+
+array recv_like(
+    const array& x,
+    int src,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
+
 } // namespace mlx::core::distributed
diff --git a/mlx/distributed/primitives.cpp b/mlx/distributed/primitives.cpp
index a115ea087..7d8499b99 100644
--- a/mlx/distributed/primitives.cpp
+++ b/mlx/distributed/primitives.cpp
@@ -35,7 +35,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
     const std::vector<int>& axes) {
   switch (reduce_type_) {
     case Sum:
-      return {{all_sum(inputs[0], group())}, axes};
+      return {{all_sum(inputs[0], group(), stream())}, axes};
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");
   }
@@ -47,7 +47,7 @@ std::vector<array> AllReduce::jvp(
     const std::vector<int>& argnums) {
   switch (reduce_type_) {
     case Sum:
-      return {all_sum(tangents[0], group())};
+      return {all_sum(tangents[0], group(), stream())};
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");
   }
@@ -75,14 +75,14 @@ void AllGather::eval_cpu(
 std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  return {{all_gather(inputs[0], group())}, axes};
+  return {{all_gather(inputs[0], group(), stream())}, axes};
 }

 std::vector<array> AllGather::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
-  return {all_gather(tangents[0], group())};
+  return {all_gather(tangents[0], group(), stream())};
 }

 std::vector<array> AllGather::vjp(
@@ -98,4 +98,29 @@ std::vector<array> AllGather::vjp(
   return {slice(cotangents[0], starts, stops)};
 }

+void Send::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  distributed::detail::send(group(), inputs[0], dst_);
+}
+
+std::pair<std::vector<array>, std::vector<int>> Send::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  return {{send(inputs[0], dst_, group(), stream())}, axes};
+}
+
+void Recv::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 0);
+  assert(outputs.size() == 1);
+
+  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
+  distributed::detail::recv(group(), outputs[0], src_);
+}
+
 } // namespace mlx::core::distributed
diff --git a/mlx/distributed/primitives.h b/mlx/distributed/primitives.h
index 4bf40b41c..7320e6cb6 100644
--- a/mlx/distributed/primitives.h
+++ b/mlx/distributed/primitives.h
@@ -97,4 +97,39 @@ class AllGather : public DistPrimitive {
   DEFINE_PRINT(AllGather);
 };

+class Send : public DistPrimitive {
+ public:
+  Send(Stream stream, Group group, int dst)
+      : DistPrimitive(stream, group), dst_(dst) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  std::pair<std::vector<array>, std::vector<int>> vmap(
+      const std::vector<array>& inputs,
+      const std::vector<int>& axes) override;
+
+  DEFINE_PRINT(Send);
+
+ private:
+  int dst_;
+};
+
+class Recv : public DistPrimitive {
+ public:
+  Recv(Stream stream, Group group, int src)
+      : DistPrimitive(stream, group), src_(src) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_PRINT(Recv);
+
+ private:
+  int src_;
+};
+
 } // namespace mlx::core::distributed
diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp
index 6395a642b..697b8bd58 100644
--- a/python/src/distributed.cpp
+++ b/python/src/distributed.cpp
@@ -4,6 +4,7 @@
 #include <nanobind/stl/optional.h>
 #include <nanobind/stl/shared_ptr.h>
 #include <nanobind/stl/variant.h>
+#include <nanobind/stl/vector.h>

 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/ops.h"
@@ -121,4 +122,90 @@
       Returns:
         array: The concatenation of all ``x`` arrays.
       )pbdoc");
+
+  m.def(
+      "send",
+      &distributed::send,
+      "x"_a,
+      "dst"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def send(x: array, dst: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Send an array from the current process to the process that has rank
+        ``dst`` in the group.
+
+        Args:
+          x (array): Input array.
+          dst (int): Rank of the destination process in the group.
+          group (Group): The group of processes that will participate in the
+            send. If set to ``None`` the global group is used. Default:
+            ``None``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``
+            in which case the default stream of the default device is used.
+
+        Returns:
+          array: An empty array which, when evaluated, performs the send.
+      )pbdoc");
+
+  m.def(
+      "recv",
+      &distributed::recv,
+      "shape"_a,
+      "dtype"_a,
+      "src"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def recv(shape: Sequence[int], dtype: Dtype, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Recv an array with shape ``shape`` and dtype ``dtype`` from process
+        with rank ``src``.
+
+        Args:
+          shape (Tuple[int]): The shape of the array we are receiving.
+          dtype (Dtype): The data type of the array we are receiving.
+          src (int): Rank of the source process in the group.
+          group (Group): The group of processes that will participate in the
+            recv. If set to ``None`` the global group is used. Default:
+            ``None``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``
+            in which case the default stream of the default device is used.
+
+        Returns:
+          array: The array that was received from ``src``.
+      )pbdoc");
+
+  m.def(
+      "recv_like",
+      &distributed::recv_like,
+      "x"_a,
+      "src"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def recv_like(x: array, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Recv an array with shape and type like ``x`` from process with rank
+        ``src``.
+
+        It is equivalent to calling ``mx.distributed.recv(x.shape, x.dtype, src)``.
+
+        Args:
+          x (array): An array defining the shape and dtype of the array we are
+            receiving.
+          src (int): Rank of the source process in the group.
+          group (Group): The group of processes that will participate in the
+            recv. If set to ``None`` the global group is used. Default:
+            ``None``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``
+            in which case the default stream of the default device is used.
+
+        Returns:
+          array: The array that was received from ``src``.
+      )pbdoc");
 }
diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py
index 6c6e96009..44f3fd4ce 100644
--- a/python/tests/mpi_test_distributed.py
+++ b/python/tests/mpi_test_distributed.py
@@ -93,6 +93,23 @@ class TestDistributed(mlx_tests.MLXTestCase):

         self.assertTrue(mx.all(z == z_target))

+    def test_send_recv(self):
+        world = mx.distributed.init()
+        pairs = world.split(world.rank() // 2)
+        neighbor = (pairs.rank() + 1) % 2
+        send = pairs.rank() == 0
+
+        x = mx.ones(10)
+        for i in range(10):
+            if send:
+                mx.eval(mx.distributed.send(2 * x, neighbor, group=pairs))
+            else:
+                x = mx.distributed.recv_like(x, neighbor, group=pairs)
+                mx.eval(x)
+            send = not send
+
+        self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
+

 if __name__ == "__main__":
     unittest.main()
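For reference (not part of the patch itself), a minimal usage sketch of the Python API added above. It assumes a two-process MPI launch; the launcher invocation, script name, and array contents are illustrative only and not taken from the patch.

# send_recv_example.py -- hypothetical file name; run with two ranks, e.g.
#   mpirun -np 2 python send_recv_example.py
import mlx.core as mx

world = mx.distributed.init()
assert world.size() == 2, "this sketch assumes exactly two processes"
peer = 1 - world.rank()  # the other rank in a two-process group

x = mx.arange(8, dtype=mx.float32)

if world.rank() == 0:
    # send() returns an empty array; evaluating it performs the send.
    mx.eval(mx.distributed.send(x, peer))
else:
    # recv_like() receives an array with the same shape and dtype as x.
    y = mx.distributed.recv_like(x, peer)  # or: recv(x.shape, x.dtype, peer)
    mx.eval(y)
    print("rank 1 received:", y)

As in the test above, the sending side wraps ``send`` in ``mx.eval`` so that the returned empty array is evaluated and the transfer is actually scheduled.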