From eeb5a0d63fa6cc2df5ab0dc34152e9868d7cd17d Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Wed, 20 Aug 2025 18:21:07 -0700
Subject: [PATCH] Put the decision of the comm stream to the group

---
 mlx/distributed/distributed.cpp    |  8 ++++++
 mlx/distributed/distributed.h      |  2 ++
 mlx/distributed/distributed_impl.h |  8 ++++++
 mlx/distributed/mpi/mpi.cpp        |  4 +++
 mlx/distributed/nccl/nccl.cpp      |  5 ++++
 mlx/distributed/ops.cpp            | 40 ++++++++++++------------------
 mlx/distributed/ring/ring.cpp      |  4 +++
 7 files changed, 47 insertions(+), 24 deletions(-)

diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp
index a65329588..877533b55 100644
--- a/mlx/distributed/distributed.cpp
+++ b/mlx/distributed/distributed.cpp
@@ -13,6 +13,10 @@ namespace mlx::core::distributed {
 
 namespace detail {
 
+Stream communication_stream(Group group, StreamOrDevice s /* = {} */) {
+  return group.raw_group()->communication_stream(s);
+}
+
 void all_sum(Group group, const array& input, array& output, Stream stream) {
   group.raw_group()->all_sum(input, output, stream);
 }
@@ -39,6 +43,10 @@ void recv(Group group, array& out, int src, Stream stream) {
 
 class EmptyGroup : public GroupImpl {
  public:
+  Stream communication_stream(StreamOrDevice s) override {
+    return to_stream(s);
+  }
+
   int rank() override {
     return 0;
   }
diff --git a/mlx/distributed/distributed.h b/mlx/distributed/distributed.h
index 0827cf3fe..fa5c42a1f 100644
--- a/mlx/distributed/distributed.h
+++ b/mlx/distributed/distributed.h
@@ -3,7 +3,9 @@
 #pragma once
 
 #include <memory>
+
 #include "mlx/array.h"
+#include "mlx/utils.h"
 
 namespace mlx::core::distributed {
 
diff --git a/mlx/distributed/distributed_impl.h b/mlx/distributed/distributed_impl.h
index 8b0327131..c90b0ba47 100644
--- a/mlx/distributed/distributed_impl.h
+++ b/mlx/distributed/distributed_impl.h
@@ -13,10 +13,15 @@ class GroupImpl {
  public:
   virtual ~GroupImpl() {}
 
+  // Choose the stream this communication group can operate on
+  virtual Stream communication_stream(StreamOrDevice s = {}) = 0;
+
+  // Group operations
   virtual int rank() = 0;
   virtual int size() = 0;
   virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
 
+  // Actual communication operations
   virtual void all_sum(const array& input, array& output, Stream stream) = 0;
   virtual void all_gather(const array& input, array& output, Stream stream) = 0;
   virtual void send(const array& input, int dst, Stream stream) = 0;
@@ -25,6 +30,9 @@ class GroupImpl {
   virtual void all_min(const array& input, array& output, Stream stream) = 0;
 };
 
+/* Define the MLX stream that the communication should happen in. */
+Stream communication_stream(Group group, StreamOrDevice s = {});
+
 /* Perform an all reduce sum operation */
 void all_sum(Group group, const array& input, array& output, Stream stream);
 
diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp
index 6a440c319..494fb02dc 100644
--- a/mlx/distributed/mpi/mpi.cpp
+++ b/mlx/distributed/mpi/mpi.cpp
@@ -349,6 +349,10 @@ class MPIGroup : public GroupImpl {
     }
   }
 
+  Stream communication_stream(StreamOrDevice s) override {
+    return to_stream(s, Device::cpu);
+  }
+
   int rank() override {
     if (rank_ < 0) {
       mpi().rank(comm_, &rank_);
diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp
index 23176c81b..43af9c724 100644
--- a/mlx/distributed/nccl/nccl.cpp
+++ b/mlx/distributed/nccl/nccl.cpp
@@ -17,6 +17,7 @@
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
 #include "mlx/dtype_utils.h"
+#include "mlx/utils.h"
 
 namespace mlx::core::distributed::nccl {
 
@@ -255,6 +256,10 @@ class NCCLGroup : public GroupImpl {
     initialized_ = false;
   }
 
+  Stream communication_stream(StreamOrDevice s) override {
+    return to_stream(s, Device::gpu);
+  }
+
   int rank() override {
     return rank_;
   }
diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp
index 900ae2c81..157bc2612 100644
--- a/mlx/distributed/ops.cpp
+++ b/mlx/distributed/ops.cpp
@@ -4,18 +4,10 @@
 
 #include "mlx/backend/cuda/cuda.h"
 #include "mlx/backend/metal/metal.h"
+#include "mlx/distributed/distributed_impl.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
 
-inline mlx::core::Device get_device() {
-  if (mlx::core::metal::is_available()) {
-    return mlx::core::Device::cpu;
-  } else if (mlx::core::cu::is_available()) {
-    return mlx::core::Device::gpu;
-  }
-  throw std::runtime_error("No available device for distributed operations.");
-}
-
 namespace mlx::core::distributed {
 
 namespace {
 
@@ -35,15 +27,16 @@ array all_sum(
     std::optional<Group> group_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
-  auto dev = get_device();
 
   if (group.size() == 1) {
     return x;
   }
 
+  auto stream = detail::communication_stream(group, s);
+
   return array(
       x.shape(),
       x.dtype(),
-      std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Sum),
+      std::make_shared<AllReduce>(stream, group, AllReduce::Sum),
       {x});
 }
 
@@ -52,15 +45,16 @@ array all_max(
     std::optional<Group> group_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
-  auto dev = get_device();
 
   if (group.size() == 1) {
     return x;
   }
 
+  auto stream = detail::communication_stream(group, s);
+
   return array(
       x.shape(),
       x.dtype(),
-      std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Max),
+      std::make_shared<AllReduce>(stream, group, AllReduce::Max),
       {x});
 }
 
@@ -69,15 +63,16 @@ array all_min(
     std::optional<Group> group_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
-  auto dev = get_device();
 
   if (group.size() == 1) {
     return x;
   }
 
+  auto stream = detail::communication_stream(group, s);
+
   return array(
       x.shape(),
       x.dtype(),
-      std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Min),
+      std::make_shared<AllReduce>(stream, group, AllReduce::Min),
       {x});
 }
 
@@ -86,11 +81,11 @@ array all_gather(
     std::optional<Group> group_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
-  auto dev = get_device();
 
   if (group.size() == 1) {
     return x;
   }
+  auto stream = detail::communication_stream(group, s);
 
   auto result_shape = x.shape();
   if (result_shape.size() == 0) {
@@ -101,7 +96,7 @@ array all_gather(
   return array(
       std::move(result_shape),
       x.dtype(),
-      std::make_shared<AllGather>(to_stream(s, dev), group),
+      std::make_shared<AllGather>(stream, group),
       {x});
 }
 
@@ -111,11 +106,11 @@ array send(
     std::optional<Group> group_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
-  auto dev = get_device();
 
   if (group.size() == 1) {
     throw std::invalid_argument("Cannot send to a singleton group");
   }
+  auto stream = detail::communication_stream(group, s);
 
   if (dst < 0 || dst >= group.size()) {
     std::ostringstream msg;
@@ -125,10 +120,7 @@ array send(
   }
 
   return array(
-      x.shape(),
-      x.dtype(),
-      std::make_shared<Send>(to_stream(s, dev), group, dst),
-      {x});
+      x.shape(), x.dtype(), std::make_shared<Send>(stream, group, dst), {x});
 }
 
 array recv(
@@ -138,11 +130,11 @@ array recv(
     std::optional<Group> group_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
-  auto dev = get_device();
 
   if (group.size() == 1) {
     throw std::invalid_argument("Cannot recv from a singleton group");
   }
+  auto stream = detail::communication_stream(group, s);
 
   if (src < 0 || src >= group.size()) {
     std::ostringstream msg;
@@ -153,7 +145,7 @@ array recv(
 
   return array(
       std::move(shape),
       std::move(dtype),
-      std::make_shared<Recv>(to_stream(s, dev), group, src),
+      std::make_shared<Recv>(stream, group, src),
       std::vector<array>{});
 }
diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp
index b31274e23..7c3dcf095 100644
--- a/mlx/distributed/ring/ring.cpp
+++ b/mlx/distributed/ring/ring.cpp
@@ -619,6 +619,10 @@ class RingGroup : public GroupImpl {
     }
   }
 
+  Stream communication_stream(StreamOrDevice s) override {
+    return to_stream(s, Device::cpu);
+  }
+
   int rank() override {
     return rank_;
   }
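
For reference, a minimal standalone sketch of the pattern this patch introduces: each communication backend decides which stream its collectives run on, instead of a global get_device() heuristic in ops.cpp. The Device, Stream, RingLikeGroup, and NcclLikeGroup types below are simplified stand-ins for illustration only, not the real MLX classes.

// Standalone sketch only: these are simplified stand-in types, not MLX's.
#include <iostream>

enum class Device { cpu, gpu };

struct Stream {
  Device device;
};

// Each backend decides where its collectives run, mirroring the new
// GroupImpl::communication_stream() hook added by this patch.
class GroupImpl {
 public:
  virtual ~GroupImpl() = default;
  virtual Stream communication_stream() = 0;
};

// CPU-side backends (ring, MPI) pin communication to a CPU stream.
class RingLikeGroup : public GroupImpl {
 public:
  Stream communication_stream() override { return Stream{Device::cpu}; }
};

// A GPU collective backend (NCCL) pins communication to a GPU stream.
class NcclLikeGroup : public GroupImpl {
 public:
  Stream communication_stream() override { return Stream{Device::gpu}; }
};

// An op no longer guesses the device globally; it asks the group it runs on.
Stream pick_stream(GroupImpl& group) {
  return group.communication_stream();
}

int main() {
  RingLikeGroup ring;
  NcclLikeGroup nccl;
  std::cout << "ring group -> "
            << (pick_stream(ring).device == Device::cpu ? "cpu" : "gpu")
            << " stream\n";
  std::cout << "nccl group -> "
            << (pick_stream(nccl).device == Device::cpu ? "cpu" : "gpu")
            << " stream\n";
  return 0;
}

With the hook on the group, MPI and ring groups naturally resolve to CPU streams, NCCL to GPU streams, and EmptyGroup falls back to the default stream, so ops.cpp no longer needs backend-specific device logic.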