Refactor distributed backend (#1752)

Angelos Katharopoulos 2025-01-06 17:33:15 -08:00 committed by GitHub
parent d5ec172c95
commit 545f84d905
9 changed files with 242 additions and 167 deletions

mlx/distributed/CMakeLists.txt

@@ -1,8 +1,7 @@
-target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
-               ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp)
+target_sources(
+  mlx
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
 
-if(MPI_FOUND AND MLX_BUILD_CPU)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
-else()
-  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
-endif()
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
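
This CMake change moves backend selection down a level: distributed.cpp now always builds, and the mpi/ subdirectory (see its CMakeLists.txt below) chooses between the real MPI bindings and a stub. The underlying pattern is link-time backend selection: each backend defines the same symbols and the build picks exactly one translation unit. A minimal standalone sketch of the pattern, with invented file names (not MLX code):

// backend.h - the interface, declared once
bool backend_available();

// real_backend.cpp - compiled only when the dependency is found
bool backend_available() { return true; }

// stub_backend.cpp - compiled otherwise; same symbol, fallback behavior
bool backend_available() { return false; }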

mlx/distributed/distributed.cpp (new file)

@@ -0,0 +1,99 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/scheduler.h"
namespace mlx::core::distributed {
namespace detail {
Stream communication_stream() {
static Stream comm_stream = new_stream(Device::cpu);
return comm_stream;
}
void all_sum(Group group, const array& input, array& output) {
group.raw_group()->all_sum(input, output);
}
void all_gather(Group group, const array& input, array& output) {
group.raw_group()->all_gather(input, output);
}
void send(Group group, const array& input, int dst) {
group.raw_group()->send(input, dst);
}
void recv(Group group, array& out, int src) {
group.raw_group()->recv(out, src);
}
class EmptyGroup : public GroupImpl {
public:
int rank() override {
return 0;
}
int size() override {
return 1;
}
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
throw std::runtime_error("Cannot split the distributed group further.");
}
void all_sum(const array& input, array& output) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void all_gather(const array& input, array& output) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void send(const array& input, int dst) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void recv(array& out, int src) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
};
} // namespace detail
bool is_available() {
return mpi::is_available();
}
int Group::rank() const {
return group_->rank();
}
int Group::size() const {
return group_->size();
}
Group Group::split(int color, int key /* = -1 */) const {
return Group(group_->split(color, key));
}
Group init(bool strict /* = false */) {
auto init_group = [strict]() {
auto default_group = mpi::init(strict);
if (default_group == nullptr) {
default_group = std::make_shared<detail::EmptyGroup>();
}
return default_group;
};
static std::shared_ptr<detail::GroupImpl> default_group = init_group();
// Ensure the communication stream is alive before
// the graph is evaluated
detail::communication_stream();
return Group(default_group);
}
} // namespace mlx::core::distributed
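
distributed.cpp is the new backend-agnostic core. The detail:: free functions forward to the virtual GroupImpl methods, and EmptyGroup keeps single-process programs working (rank 0, size 1) while turning attempted communication into a runtime error instead of a silent no-op. A hedged usage sketch of the public facade, assuming only the declarations in this diff:

#include "mlx/distributed/distributed.h"

int main() {
  // Falls back to EmptyGroup when no MPI backend is available.
  auto world = mlx::core::distributed::init();
  if (world.size() > 1) {
    // Split by rank parity; a negative key keeps the current rank order.
    auto half = world.split(world.rank() % 2);
  }
  return 0;
}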

mlx/distributed/distributed.h

@@ -8,6 +8,11 @@
 namespace mlx::core::distributed {
 
+// Forward declaration of the base group implementation.
+namespace detail {
+class GroupImpl;
+};
+
 /* Check if a communication backend is available */
 bool is_available();
@@ -17,10 +22,10 @@ bool is_available();
  * order to define more granular communication.
  */
 struct Group {
-  Group(std::shared_ptr<void> group) : group_(group) {}
+  Group(std::shared_ptr<detail::GroupImpl> group) : group_(std::move(group)) {}
 
-  int rank();
-  int size();
+  int rank() const;
+  int size() const;
 
   /**
    * Split the group according to the provided color. Namely processes that use
@@ -30,14 +35,14 @@ struct Group {
    * the key the smaller the rank. If the provided key is negative, then the
    * rank in the current group is used.
    */
-  Group split(int color, int key = -1);
+  Group split(int color, int key = -1) const;
 
-  const std::shared_ptr<void>& raw_group() {
+  const std::shared_ptr<detail::GroupImpl>& raw_group() const {
     return group_;
   }
 
  private:
-  std::shared_ptr<void> group_{nullptr};
+  std::shared_ptr<detail::GroupImpl> group_{nullptr};
 };
 
 /**
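
The key header change replaces the type-erased std::shared_ptr<void> with a forward-declared detail::GroupImpl, so Group methods can use virtual dispatch instead of backend-specific casts. A standalone sketch of the difference (illustrative types, not MLX code):

#include <memory>

struct Impl {
  virtual ~Impl() = default;
  virtual int rank() = 0;
};
struct MpiLike : Impl {
  int rank() override { return 0; }
};

// Before: every call site must know, and trust, the concrete backend type.
int rank_erased(const std::shared_ptr<void>& g) {
  return std::static_pointer_cast<MpiLike>(g)->rank();
}

// After: virtual dispatch works for any backend, with no casts.
int rank_typed(const std::shared_ptr<Impl>& g) {
  return g->rank();
}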

mlx/distributed/distributed_impl.h

@@ -6,6 +6,21 @@
 namespace mlx::core::distributed::detail {
 
+/**
+ * Abstract base class of a distributed group implementation.
+ */
+class GroupImpl {
+ public:
+  virtual int rank() = 0;
+  virtual int size() = 0;
+  virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
+
+  virtual void all_sum(const array& input, array& output) = 0;
+  virtual void all_gather(const array& input, array& output) = 0;
+  virtual void send(const array& input, int dst) = 0;
+  virtual void recv(array& out, int src) = 0;
+};
+
 /* Return the communication stream. */
 Stream communication_stream();
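
GroupImpl is now the single extension point for communication backends: implement these virtuals plus an is_available/init pair and the core code never changes. As a hedged sketch, a hypothetical backend's header would mirror mpi.h below (the "ring" name is invented for illustration):

#include <memory>

#include "mlx/distributed/distributed_impl.h"

namespace mlx::core::distributed::ring {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

// Same contract as the mpi namespace: report availability, and have init
// return a GroupImpl, or nullptr (non-strict) / throw (strict) on failure.
bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);

} // namespace mlx::core::distributed::ring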

mlx/distributed/mpi/CMakeLists.txt

@@ -1 +1,5 @@
-target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
+if(MPI_FOUND AND MLX_BUILD_CPU)
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
+else()
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp)
+endif()

mlx/distributed/mpi/mpi.cpp

@@ -6,6 +6,7 @@
 #include "mlx/backend/common/copy.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
+#include "mlx/distributed/mpi/mpi.h"
 #include "mlx/scheduler.h"
 
 #define LOAD_SYMBOL(symbol, variable) \
@@ -18,7 +19,9 @@
   }                                   \
   }
 
-namespace mlx::core::distributed {
+namespace mlx::core::distributed::mpi {
+
+using GroupImpl = mlx::core::distributed::detail::GroupImpl;
 
 namespace {
 
@@ -233,11 +236,14 @@ MPIWrapper& mpi() {
   return wrapper;
 }
 
-struct MPIGroupImpl {
-  MPIGroupImpl() : comm_(nullptr), global_(true), rank_(0), size_(1) {}
-  MPIGroupImpl(MPI_Comm comm, bool global)
+} // namespace
+
+class MPIGroup : public GroupImpl {
+ public:
+  MPIGroup(MPI_Comm comm, bool global)
       : comm_(comm), global_(global), rank_(-1), size_(-1) {}
-  ~MPIGroupImpl() {
+
+  virtual ~MPIGroup() {
     if (global_) {
       mpi().finalize_safe();
     } else {
@@ -245,24 +251,74 @@ struct MPIGroupImpl {
     }
   }
 
-  MPI_Comm comm() {
-    return comm_;
-  }
-
-  int rank() {
+  int rank() override {
     if (rank_ < 0) {
       mpi().rank(comm_, &rank_);
     }
     return rank_;
   }
 
-  int size() {
+  int size() override {
     if (size_ < 0) {
       mpi().size(comm_, &size_);
     }
     return size_;
   }
 
+  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
+    key = (key < 0) ? rank() : key;
+
+    MPI_Comm new_comm;
+    int result = mpi().comm_split(comm_, color, key, &new_comm);
+    if (result != MPI_SUCCESS) {
+      throw std::runtime_error("MPI could not split this group");
+    }
+
+    return std::make_shared<MPIGroup>(new_comm, false);
+  }
+
+  void all_sum(const array& input_, array& output) override {
+    array input = ensure_row_contiguous(input_);
+    mpi().all_reduce(
+        (input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
+                                                    : input.data<void>(),
+        output.data<void>(),
+        input.size(),
+        mpi().datatype(input),
+        mpi().op_sum(input),
+        comm_);
+  }
+
+  void all_gather(const array& input_, array& output) override {
+    array input = ensure_row_contiguous(input_);
+    mpi().all_gather(
+        input.data<void>(),
+        input.size(),
+        mpi().datatype(input),
+        output.data<void>(),
+        input.size(),
+        mpi().datatype(output),
+        comm_);
+  }
+
+  void send(const array& input_, int dst) override {
+    array input = ensure_row_contiguous(input_);
+    mpi().send(
+        input.data<void>(), input.size(), mpi().datatype(input), dst, 0, comm_);
+  }
+
+  void recv(array& out, int src) override {
+    MPI_Status status;
+    mpi().recv(
+        out.data<void>(),
+        out.size(),
+        mpi().datatype(out),
+        src,
+        MPI_ANY_TAG,
+        comm_,
+        &status);
+  }
+
  private:
   MPI_Comm comm_;
   bool global_;
@@ -270,112 +326,19 @@ struct MPIGroupImpl {
   int size_;
 };
 
-MPI_Comm to_comm(Group& group) {
-  return std::static_pointer_cast<MPIGroupImpl>(group.raw_group())->comm();
-}
-
-} // namespace
-
-int Group::rank() {
-  return std::static_pointer_cast<MPIGroupImpl>(group_)->rank();
-}
-
-int Group::size() {
-  return std::static_pointer_cast<MPIGroupImpl>(group_)->size();
-}
-
-Group Group::split(int color, int key) {
-  auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_);
-  key = (key < 0) ? rank() : key;
-  MPI_Comm new_comm;
-  int result = mpi().comm_split(mpi_group->comm(), color, key, &new_comm);
-  if (result != MPI_SUCCESS) {
-    throw std::runtime_error("MPI could not split this group");
-  }
-  return Group(std::make_shared<MPIGroupImpl>(new_comm, false));
-}
-
 bool is_available() {
   return mpi().is_available();
 }
 
-Group init(bool strict /* = false */) {
-  static std::shared_ptr<MPIGroupImpl> global_group = nullptr;
-
-  if (global_group == nullptr) {
-    if (!mpi().init_safe()) {
-      if (strict) {
-        throw std::runtime_error("Cannot initialize MPI");
-      }
-      global_group = std::make_shared<MPIGroupImpl>();
-    } else {
-      global_group = std::make_shared<MPIGroupImpl>(mpi().world(), true);
-    }
-  }
-
-  // Ensure the communication stream is alive before
-  // the graph is evaluated
-  detail::communication_stream();
-
-  return Group(global_group);
-}
+std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
+  if (!mpi().init_safe()) {
+    if (strict) {
+      throw std::runtime_error("Cannot initialize MPI");
+    }
+    return nullptr;
+  }
+
+  return std::make_shared<MPIGroup>(mpi().world(), true);
+}
 
-namespace detail {
-
-Stream communication_stream() {
-  static Stream comm_stream = new_stream(Device::cpu);
-  return comm_stream;
-}
-
-void all_sum(Group group, const array& input_, array& output) {
-  array input = ensure_row_contiguous(input_);
-  mpi().all_reduce(
-      (input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
-                                                  : input.data<void>(),
-      output.data<void>(),
-      input.size(),
-      mpi().datatype(input),
-      mpi().op_sum(input),
-      to_comm(group));
-}
-
-void all_gather(Group group, const array& input_, array& output) {
-  array input = ensure_row_contiguous(input_);
-  mpi().all_gather(
-      input.data<void>(),
-      input.size(),
-      mpi().datatype(input),
-      output.data<void>(),
-      input.size(),
-      mpi().datatype(output),
-      to_comm(group));
-}
-
-void send(Group group, const array& input_, int dst) {
-  array input = ensure_row_contiguous(input_);
-  mpi().send(
-      input.data<void>(),
-      input.size(),
-      mpi().datatype(input),
-      dst,
-      0,
-      to_comm(group));
-}
-
-void recv(Group group, array& out, int src) {
-  MPI_Status status;
-  mpi().recv(
-      out.data<void>(),
-      out.size(),
-      mpi().datatype(out),
-      src,
-      MPI_ANY_TAG,
-      to_comm(group),
-      &status);
-}
-
-} // namespace detail
-
-} // namespace mlx::core::distributed
+} // namespace mlx::core::distributed::mpi
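
Two details of MPIGroup are easy to miss: rank_ and size_ start at -1 and are fetched lazily from MPI on first use, and all_sum substitutes MPI_IN_PLACE when the input and output buffers alias, since MPI forbids passing the same pointer as both the send and receive buffer. A standalone sketch of the in-place pattern against the plain MPI C API (assumes an MPI installation; build with mpicc or mpic++):

#include <mpi.h>

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  float acc[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  // Aliased send/recv buffers must be expressed as MPI_IN_PLACE on the send
  // side; passing acc as both arguments is invalid under the MPI standard.
  MPI_Allreduce(MPI_IN_PLACE, acc, 4, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD);
  MPI_Finalize();
  return 0;
}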

mlx/distributed/mpi/mpi.h (new file)

@@ -0,0 +1,12 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::mpi {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);
} // namespace mlx::core::distributed::mpi
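
The strict flag gives callers both initialization modes through one entry point: non-strict init returns nullptr on failure (and distributed::init substitutes EmptyGroup), while strict init throws. A hedged usage sketch:

#include "mlx/distributed/distributed.h"

void example() {
  // Permissive: always yields a group, possibly a single-process EmptyGroup.
  auto group = mlx::core::distributed::init(/* strict = */ false);

  // Strict: throws std::runtime_error if MPI cannot be initialized.
  // auto mpi_group = mlx::core::distributed::init(/* strict = */ true);
}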

mlx/distributed/mpi/no_mpi.cpp (new file)

@@ -0,0 +1,20 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/mpi/mpi.h"
namespace mlx::core::distributed::mpi {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available() {
return false;
}
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
if (strict) {
throw std::runtime_error("Cannot initialize MPI");
}
return nullptr;
}
} // namespace mlx::core::distributed::mpi

mlx/distributed/no_distributed.cpp (deleted)

@@ -1,42 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
namespace mlx::core::distributed {
int Group::rank() {
return 0;
}
int Group::size() {
return 1;
}
Group Group::split(int color, int key) {
throw std::runtime_error("Cannot split the distributed group further");
}
bool is_available() {
return false;
}
Group init(bool strict /* = false */) {
return Group(nullptr);
}
namespace detail {
Stream communication_stream() {
static Stream comm_stream = new_stream(Device::cpu);
return comm_stream;
}
void all_sum(Group group, const array& input, array& output) {}
void all_gather(Group group, const array& input, array& output) {}
void send(Group group, const array& input, int dst) {}
void recv(Group group, array& out, int src) {}
} // namespace detail
} // namespace mlx::core::distributed
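
With this file gone, the single-process case is served by EmptyGroup in distributed.cpp: rank 0 and size 1 are still reported, but collectives that used to be silent no-ops here now throw. A small test-style sketch of the resulting contract (assumed usage, not from the MLX test suite):

#include <cassert>

#include "mlx/distributed/distributed.h"

int main() {
  auto g = mlx::core::distributed::init(); // EmptyGroup without a backend
  assert(g.rank() == 0);
  assert(g.size() == 1);
  // detail::all_sum on this group now throws std::runtime_error, where the
  // old stub above silently did nothing.
  return 0;
}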