diff --git a/mlx/distributed/CMakeLists.txt b/mlx/distributed/CMakeLists.txt index 4009196eb..8fd081b84 100644 --- a/mlx/distributed/CMakeLists.txt +++ b/mlx/distributed/CMakeLists.txt @@ -1,8 +1,7 @@ -target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp) +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp) -if(MPI_FOUND AND MLX_BUILD_CPU) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) -else() - target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp) -endif() +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp new file mode 100644 index 000000000..34d9583fa --- /dev/null +++ b/mlx/distributed/distributed.cpp @@ -0,0 +1,99 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/distributed/distributed.h" +#include "mlx/distributed/distributed_impl.h" +#include "mlx/distributed/mpi/mpi.h" +#include "mlx/scheduler.h" + +namespace mlx::core::distributed { + +namespace detail { + +Stream communication_stream() { + static Stream comm_stream = new_stream(Device::cpu); + return comm_stream; +} + +void all_sum(Group group, const array& input, array& output) { + group.raw_group()->all_sum(input, output); +} + +void all_gather(Group group, const array& input, array& output) { + group.raw_group()->all_gather(input, output); +} + +void send(Group group, const array& input, int dst) { + group.raw_group()->send(input, dst); +} + +void recv(Group group, array& out, int src) { + group.raw_group()->recv(out, src); +} + +class EmptyGroup : public GroupImpl { + public: + int rank() override { + return 0; + } + + int size() override { + return 1; + } + + std::shared_ptr split(int color, int key = -1) override { + throw std::runtime_error("Cannot split the distributed group further."); + } + + void all_sum(const array& input, array& output) override { + throw std::runtime_error( + "Communication not implemented in an empty distributed group."); + } + void all_gather(const array& input, array& output) override { + throw std::runtime_error( + "Communication not implemented in an empty distributed group."); + } + void send(const array& input, int dst) override { + throw std::runtime_error( + "Communication not implemented in an empty distributed group."); + } + void recv(array& out, int src) override { + throw std::runtime_error( + "Communication not implemented in an empty distributed group."); + } +}; + +} // namespace detail + +bool is_available() { + return mpi::is_available(); +} + +int Group::rank() const { + return group_->rank(); +} + +int Group::size() const { + return group_->size(); +} + +Group Group::split(int color, int key /* = -1 */) const { + return Group(group_->split(color, key)); +} + +Group init(bool strict /* = false */) { + auto init_group = [strict]() { + auto default_group = mpi::init(strict); + if (default_group == nullptr) { + default_group = std::make_shared(); + } + return default_group; + }; + static std::shared_ptr default_group = init_group(); + + // Ensure the communication stream is alive before + // the graph is evaluated + detail::communication_stream(); + return Group(default_group); +} + +} // namespace mlx::core::distributed diff --git a/mlx/distributed/distributed.h b/mlx/distributed/distributed.h index 1ed82cb6a..c06d10756 100644 --- a/mlx/distributed/distributed.h +++ b/mlx/distributed/distributed.h @@ -8,6 +8,11 @@ namespace mlx::core::distributed { +// Forward declaration of the base group implementation. +namespace detail { +class GroupImpl; +}; + /* Check if a communication backend is available */ bool is_available(); @@ -17,10 +22,10 @@ bool is_available(); * order to define more granular communication. */ struct Group { - Group(std::shared_ptr group) : group_(group) {} + Group(std::shared_ptr group) : group_(std::move(group)) {} - int rank(); - int size(); + int rank() const; + int size() const; /** * Split the group according to the provided color. Namely processes that use @@ -30,14 +35,14 @@ struct Group { * the key the smaller the rank. If the provided key is negative, then the * rank in the current group is used. */ - Group split(int color, int key = -1); + Group split(int color, int key = -1) const; - const std::shared_ptr& raw_group() { + const std::shared_ptr& raw_group() const { return group_; } private: - std::shared_ptr group_{nullptr}; + std::shared_ptr group_{nullptr}; }; /** diff --git a/mlx/distributed/distributed_impl.h b/mlx/distributed/distributed_impl.h index 7031e33f5..fdcbf777d 100644 --- a/mlx/distributed/distributed_impl.h +++ b/mlx/distributed/distributed_impl.h @@ -6,6 +6,21 @@ namespace mlx::core::distributed::detail { +/** + * Abstract base class of a distributed group implementation. + */ +class GroupImpl { + public: + virtual int rank() = 0; + virtual int size() = 0; + virtual std::shared_ptr split(int color, int key = -1) = 0; + + virtual void all_sum(const array& input, array& output) = 0; + virtual void all_gather(const array& input, array& output) = 0; + virtual void send(const array& input, int dst) = 0; + virtual void recv(array& out, int src) = 0; +}; + /* Return the communication stream. */ Stream communication_stream(); diff --git a/mlx/distributed/mpi/CMakeLists.txt b/mlx/distributed/mpi/CMakeLists.txt index 0e47d4347..7063a101f 100644 --- a/mlx/distributed/mpi/CMakeLists.txt +++ b/mlx/distributed/mpi/CMakeLists.txt @@ -1 +1,5 @@ -target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp) +if(MPI_FOUND AND MLX_BUILD_CPU) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp) +else() + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp) +endif() diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index 3223832e5..b233df1b5 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -6,6 +6,7 @@ #include "mlx/backend/common/copy.h" #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" +#include "mlx/distributed/mpi/mpi.h" #include "mlx/scheduler.h" #define LOAD_SYMBOL(symbol, variable) \ @@ -18,7 +19,9 @@ } \ } -namespace mlx::core::distributed { +namespace mlx::core::distributed::mpi { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; namespace { @@ -233,11 +236,14 @@ MPIWrapper& mpi() { return wrapper; } -struct MPIGroupImpl { - MPIGroupImpl() : comm_(nullptr), global_(true), rank_(0), size_(1) {} - MPIGroupImpl(MPI_Comm comm, bool global) +} // namespace + +class MPIGroup : public GroupImpl { + public: + MPIGroup(MPI_Comm comm, bool global) : comm_(comm), global_(global), rank_(-1), size_(-1) {} - ~MPIGroupImpl() { + + virtual ~MPIGroup() { if (global_) { mpi().finalize_safe(); } else { @@ -245,24 +251,74 @@ struct MPIGroupImpl { } } - MPI_Comm comm() { - return comm_; - } - - int rank() { + int rank() override { if (rank_ < 0) { mpi().rank(comm_, &rank_); } return rank_; } - int size() { + int size() override { if (size_ < 0) { mpi().size(comm_, &size_); } return size_; } + std::shared_ptr split(int color, int key = -1) override { + key = (key < 0) ? rank() : key; + + MPI_Comm new_comm; + int result = mpi().comm_split(comm_, color, key, &new_comm); + if (result != MPI_SUCCESS) { + throw std::runtime_error("MPI could not split this group"); + } + + return std::make_shared(new_comm, false); + } + + void all_sum(const array& input_, array& output) override { + array input = ensure_row_contiguous(input_); + mpi().all_reduce( + (input.data() == output.data()) ? MPI_IN_PLACE + : input.data(), + output.data(), + input.size(), + mpi().datatype(input), + mpi().op_sum(input), + comm_); + } + + void all_gather(const array& input_, array& output) override { + array input = ensure_row_contiguous(input_); + mpi().all_gather( + input.data(), + input.size(), + mpi().datatype(input), + output.data(), + input.size(), + mpi().datatype(output), + comm_); + } + + void send(const array& input_, int dst) override { + array input = ensure_row_contiguous(input_); + mpi().send( + input.data(), input.size(), mpi().datatype(input), dst, 0, comm_); + } + + void recv(array& out, int src) override { + MPI_Status status; + mpi().recv( + out.data(), + out.size(), + mpi().datatype(out), + src, + MPI_ANY_TAG, + comm_, + &status); + } + private: MPI_Comm comm_; bool global_; @@ -270,112 +326,19 @@ struct MPIGroupImpl { int size_; }; -MPI_Comm to_comm(Group& group) { - return std::static_pointer_cast(group.raw_group())->comm(); -} - -} // namespace - -int Group::rank() { - return std::static_pointer_cast(group_)->rank(); -} - -int Group::size() { - return std::static_pointer_cast(group_)->size(); -} - -Group Group::split(int color, int key) { - auto mpi_group = std::static_pointer_cast(group_); - - key = (key < 0) ? rank() : key; - - MPI_Comm new_comm; - int result = mpi().comm_split(mpi_group->comm(), color, key, &new_comm); - if (result != MPI_SUCCESS) { - throw std::runtime_error("MPI could not split this group"); - } - - return Group(std::make_shared(new_comm, false)); -} - bool is_available() { return mpi().is_available(); } -Group init(bool strict /* = false */) { - static std::shared_ptr global_group = nullptr; - - if (global_group == nullptr) { - if (!mpi().init_safe()) { - if (strict) { - throw std::runtime_error("Cannot initialize MPI"); - } - global_group = std::make_shared(); - } else { - global_group = std::make_shared(mpi().world(), true); +std::shared_ptr init(bool strict /* = false */) { + if (!mpi().init_safe()) { + if (strict) { + throw std::runtime_error("Cannot initialize MPI"); } + return nullptr; } - // Ensure the communication stream is alive before - // the graph is evaluated - detail::communication_stream(); - return Group(global_group); + return std::make_shared(mpi().world(), true); } -namespace detail { - -Stream communication_stream() { - static Stream comm_stream = new_stream(Device::cpu); - return comm_stream; -} - -void all_sum(Group group, const array& input_, array& output) { - array input = ensure_row_contiguous(input_); - mpi().all_reduce( - (input.data() == output.data()) ? MPI_IN_PLACE - : input.data(), - output.data(), - input.size(), - mpi().datatype(input), - mpi().op_sum(input), - to_comm(group)); -} - -void all_gather(Group group, const array& input_, array& output) { - array input = ensure_row_contiguous(input_); - mpi().all_gather( - input.data(), - input.size(), - mpi().datatype(input), - output.data(), - input.size(), - mpi().datatype(output), - to_comm(group)); -} - -void send(Group group, const array& input_, int dst) { - array input = ensure_row_contiguous(input_); - mpi().send( - input.data(), - input.size(), - mpi().datatype(input), - dst, - 0, - to_comm(group)); -} - -void recv(Group group, array& out, int src) { - MPI_Status status; - mpi().recv( - out.data(), - out.size(), - mpi().datatype(out), - src, - MPI_ANY_TAG, - to_comm(group), - &status); -} - -} // namespace detail - -} // namespace mlx::core::distributed +} // namespace mlx::core::distributed::mpi diff --git a/mlx/distributed/mpi/mpi.h b/mlx/distributed/mpi/mpi.h new file mode 100644 index 000000000..cd11a4785 --- /dev/null +++ b/mlx/distributed/mpi/mpi.h @@ -0,0 +1,12 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/distributed/distributed.h" + +namespace mlx::core::distributed::mpi { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available(); +std::shared_ptr init(bool strict = false); + +} // namespace mlx::core::distributed::mpi diff --git a/mlx/distributed/mpi/no_mpi.cpp b/mlx/distributed/mpi/no_mpi.cpp new file mode 100644 index 000000000..4a7fc6653 --- /dev/null +++ b/mlx/distributed/mpi/no_mpi.cpp @@ -0,0 +1,20 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/distributed/mpi/mpi.h" + +namespace mlx::core::distributed::mpi { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available() { + return false; +} + +std::shared_ptr init(bool strict /* = false */) { + if (strict) { + throw std::runtime_error("Cannot initialize MPI"); + } + return nullptr; +} + +} // namespace mlx::core::distributed::mpi diff --git a/mlx/distributed/no_distributed.cpp b/mlx/distributed/no_distributed.cpp deleted file mode 100644 index 009e3a715..000000000 --- a/mlx/distributed/no_distributed.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#include "mlx/distributed/distributed.h" -#include "mlx/distributed/distributed_impl.h" - -namespace mlx::core::distributed { - -int Group::rank() { - return 0; -} - -int Group::size() { - return 1; -} - -Group Group::split(int color, int key) { - throw std::runtime_error("Cannot split the distributed group further"); -} - -bool is_available() { - return false; -} - -Group init(bool strict /* = false */) { - return Group(nullptr); -} - -namespace detail { - -Stream communication_stream() { - static Stream comm_stream = new_stream(Device::cpu); - return comm_stream; -} - -void all_sum(Group group, const array& input, array& output) {} -void all_gather(Group group, const array& input, array& output) {} -void send(Group group, const array& input, int dst) {} -void recv(Group group, array& out, int src) {} - -} // namespace detail - -} // namespace mlx::core::distributed