Refactor distributed backend (#1752)
parent d5ec172c95, commit 545f84d905
mlx/distributed/CMakeLists.txt
@@ -1,8 +1,7 @@
-target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
-                           ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp)
+target_sources(
+  mlx
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
 
-if(MPI_FOUND AND MLX_BUILD_CPU)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
-else()
-  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
-endif()
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
mlx/distributed/distributed.cpp (new file, 99 lines)
@@ -0,0 +1,99 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/distributed/distributed.h"
+#include "mlx/distributed/distributed_impl.h"
+#include "mlx/distributed/mpi/mpi.h"
+#include "mlx/scheduler.h"
+
+namespace mlx::core::distributed {
+
+namespace detail {
+
+Stream communication_stream() {
+  static Stream comm_stream = new_stream(Device::cpu);
+  return comm_stream;
+}
+
+void all_sum(Group group, const array& input, array& output) {
+  group.raw_group()->all_sum(input, output);
+}
+
+void all_gather(Group group, const array& input, array& output) {
+  group.raw_group()->all_gather(input, output);
+}
+
+void send(Group group, const array& input, int dst) {
+  group.raw_group()->send(input, dst);
+}
+
+void recv(Group group, array& out, int src) {
+  group.raw_group()->recv(out, src);
+}
+
+class EmptyGroup : public GroupImpl {
+ public:
+  int rank() override {
+    return 0;
+  }
+
+  int size() override {
+    return 1;
+  }
+
+  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
+    throw std::runtime_error("Cannot split the distributed group further.");
+  }
+
+  void all_sum(const array& input, array& output) override {
+    throw std::runtime_error(
+        "Communication not implemented in an empty distributed group.");
+  }
+  void all_gather(const array& input, array& output) override {
+    throw std::runtime_error(
+        "Communication not implemented in an empty distributed group.");
+  }
+  void send(const array& input, int dst) override {
+    throw std::runtime_error(
+        "Communication not implemented in an empty distributed group.");
+  }
+  void recv(array& out, int src) override {
+    throw std::runtime_error(
+        "Communication not implemented in an empty distributed group.");
+  }
+};
+
+} // namespace detail
+
+bool is_available() {
+  return mpi::is_available();
+}
+
+int Group::rank() const {
+  return group_->rank();
+}
+
+int Group::size() const {
+  return group_->size();
+}
+
+Group Group::split(int color, int key /* = -1 */) const {
+  return Group(group_->split(color, key));
+}
+
+Group init(bool strict /* = false */) {
+  auto init_group = [strict]() {
+    auto default_group = mpi::init(strict);
+    if (default_group == nullptr) {
+      default_group = std::make_shared<detail::EmptyGroup>();
+    }
+    return default_group;
+  };
+  static std::shared_ptr<detail::GroupImpl> default_group = init_group();
+
+  // Ensure the communication stream is alive before
+  // the graph is evaluated
+  detail::communication_stream();
+  return Group(default_group);
+}
+
+} // namespace mlx::core::distributed
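The new distributed.cpp is now the only translation unit that knows about concrete backends: init asks mpi::init for a group, falls back to EmptyGroup when it gets nullptr, caches the result in a function-local static, and touches communication_stream() so the CPU stream outlives lazily evaluated graphs. A minimal caller-side sketch of the resulting behavior, assuming a binary compiled and linked against an MLX build that contains this commit (the main below is illustrative, not part of the diff):

// Hypothetical single-process caller; not part of the commit.
#include <cstdio>

#include "mlx/distributed/distributed.h"

int main() {
  namespace dist = mlx::core::distributed;

  // With no usable MPI, init(false) does not throw: mpi::init returns
  // nullptr and the EmptyGroup above takes over (rank 0, size 1).
  dist::Group world = dist::init(/* strict = */ false);
  std::printf("rank %d of %d\n", world.rank(), world.size());
  return 0;
}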
mlx/distributed/distributed.h
@@ -8,6 +8,11 @@
 
 namespace mlx::core::distributed {
 
+// Forward declaration of the base group implementation.
+namespace detail {
+class GroupImpl;
+};
+
 /* Check if a communication backend is available */
 bool is_available();
 
@@ -17,10 +22,10 @@ bool is_available();
  * order to define more granular communication.
  */
 struct Group {
-  Group(std::shared_ptr<void> group) : group_(group) {}
+  Group(std::shared_ptr<detail::GroupImpl> group) : group_(std::move(group)) {}
 
-  int rank();
-  int size();
+  int rank() const;
+  int size() const;
 
   /**
    * Split the group according to the provided color. Namely processes that use
@@ -30,14 +35,14 @@ struct Group {
    * the key the smaller the rank. If the provided key is negative, then the
    * rank in the current group is used.
    */
-  Group split(int color, int key = -1);
+  Group split(int color, int key = -1) const;
 
-  const std::shared_ptr<void>& raw_group() {
+  const std::shared_ptr<detail::GroupImpl>& raw_group() const {
     return group_;
   }
 
  private:
-  std::shared_ptr<void> group_{nullptr};
+  std::shared_ptr<detail::GroupImpl> group_{nullptr};
 };
 
 /**
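The color/key contract of split matches MPI_Comm_split: equal colors land in the same sub-group, keys order the ranks within it, and a negative key falls back to the rank in the current group. A worked illustration (process count and values invented for the example):

// Illustration only; assumes a launch across 8 processes, e.g. mpirun -np 8.
#include <cstdio>

#include "mlx/distributed/distributed.h"

int main() {
  auto world = mlx::core::distributed::init();
  int color = world.rank() / 4;   // ranks 0-3 get color 0, ranks 4-7 color 1
  auto row = world.split(color);  // key = -1: keep the ordering of `world`
  // e.g. on world.rank() == 5: color == 1, row.size() == 4, row.rank() == 1
  std::printf(
      "world %d -> group %d, rank %d\n", world.rank(), color, row.rank());
  return 0;
}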
mlx/distributed/distributed_impl.h
@@ -6,6 +6,21 @@
 
 namespace mlx::core::distributed::detail {
 
+/**
+ * Abstract base class of a distributed group implementation.
+ */
+class GroupImpl {
+ public:
+  virtual int rank() = 0;
+  virtual int size() = 0;
+  virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
+
+  virtual void all_sum(const array& input, array& output) = 0;
+  virtual void all_gather(const array& input, array& output) = 0;
+  virtual void send(const array& input, int dst) = 0;
+  virtual void recv(array& out, int src) = 0;
+};
+
 /* Return the communication stream. */
 Stream communication_stream();
 
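GroupImpl is now the entire backend surface: a transport plugs in by overriding these seven members and handing an instance back from its own init. A skeletal sketch of a hypothetical backend (FakeGroup and everything in it is invented for illustration; no such file exists in this commit):

// Hypothetical backend skeleton; not part of the commit.
#include <memory>

#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"

namespace detail = mlx::core::distributed::detail;
using mlx::core::array;

class FakeGroup : public detail::GroupImpl {
 public:
  int rank() override { return 0; }
  int size() override { return 1; }

  std::shared_ptr<detail::GroupImpl> split(int color, int key = -1) override {
    return std::make_shared<FakeGroup>();  // a real backend would subdivide
  }

  // A real transport moves bytes here; compare MPIGroup in mpi.cpp below.
  void all_sum(const array& input, array& output) override {}
  void all_gather(const array& input, array& output) override {}
  void send(const array& input, int dst) override {}
  void recv(array& out, int src) override {}
};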
mlx/distributed/mpi/CMakeLists.txt
@@ -1 +1,5 @@
-target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
+if(MPI_FOUND AND MLX_BUILD_CPU)
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
+else()
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp)
+endif()
mlx/distributed/mpi/mpi.cpp
@@ -6,6 +6,7 @@
 #include "mlx/backend/common/copy.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
+#include "mlx/distributed/mpi/mpi.h"
 #include "mlx/scheduler.h"
 
 #define LOAD_SYMBOL(symbol, variable) \
@@ -18,7 +19,9 @@
   } \
   }
 
-namespace mlx::core::distributed {
+namespace mlx::core::distributed::mpi {
 
+using GroupImpl = mlx::core::distributed::detail::GroupImpl;
+
 namespace {
 
@@ -233,11 +236,14 @@ MPIWrapper& mpi() {
   return wrapper;
 }
 
-struct MPIGroupImpl {
-  MPIGroupImpl() : comm_(nullptr), global_(true), rank_(0), size_(1) {}
-  MPIGroupImpl(MPI_Comm comm, bool global)
+} // namespace
+
+class MPIGroup : public GroupImpl {
+ public:
+  MPIGroup(MPI_Comm comm, bool global)
       : comm_(comm), global_(global), rank_(-1), size_(-1) {}
-  ~MPIGroupImpl() {
+
+  virtual ~MPIGroup() {
     if (global_) {
       mpi().finalize_safe();
     } else {
@@ -245,24 +251,74 @@ struct MPIGroupImpl {
     }
   }
 
-  MPI_Comm comm() {
-    return comm_;
-  }
-
-  int rank() {
+  int rank() override {
     if (rank_ < 0) {
       mpi().rank(comm_, &rank_);
     }
     return rank_;
   }
 
-  int size() {
+  int size() override {
     if (size_ < 0) {
       mpi().size(comm_, &size_);
     }
     return size_;
   }
 
+  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
+    key = (key < 0) ? rank() : key;
+
+    MPI_Comm new_comm;
+    int result = mpi().comm_split(comm_, color, key, &new_comm);
+    if (result != MPI_SUCCESS) {
+      throw std::runtime_error("MPI could not split this group");
+    }
+
+    return std::make_shared<MPIGroup>(new_comm, false);
+  }
+
+  void all_sum(const array& input_, array& output) override {
+    array input = ensure_row_contiguous(input_);
+    mpi().all_reduce(
+        (input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
+                                                    : input.data<void>(),
+        output.data<void>(),
+        input.size(),
+        mpi().datatype(input),
+        mpi().op_sum(input),
+        comm_);
+  }
+
+  void all_gather(const array& input_, array& output) override {
+    array input = ensure_row_contiguous(input_);
+    mpi().all_gather(
+        input.data<void>(),
+        input.size(),
+        mpi().datatype(input),
+        output.data<void>(),
+        input.size(),
+        mpi().datatype(output),
+        comm_);
+  }
+
+  void send(const array& input_, int dst) override {
+    array input = ensure_row_contiguous(input_);
+    mpi().send(
+        input.data<void>(), input.size(), mpi().datatype(input), dst, 0, comm_);
+  }
+
+  void recv(array& out, int src) override {
+    MPI_Status status;
+    mpi().recv(
+        out.data<void>(),
+        out.size(),
+        mpi().datatype(out),
+        src,
+        MPI_ANY_TAG,
+        comm_,
+        &status);
+  }
+
  private:
   MPI_Comm comm_;
   bool global_;
@@ -270,112 +326,19 @@ struct MPIGroupImpl {
   int size_;
 };
 
-MPI_Comm to_comm(Group& group) {
-  return std::static_pointer_cast<MPIGroupImpl>(group.raw_group())->comm();
-}
-
-} // namespace
-
-int Group::rank() {
-  return std::static_pointer_cast<MPIGroupImpl>(group_)->rank();
-}
-
-int Group::size() {
-  return std::static_pointer_cast<MPIGroupImpl>(group_)->size();
-}
-
-Group Group::split(int color, int key) {
-  auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_);
-
-  key = (key < 0) ? rank() : key;
-
-  MPI_Comm new_comm;
-  int result = mpi().comm_split(mpi_group->comm(), color, key, &new_comm);
-  if (result != MPI_SUCCESS) {
-    throw std::runtime_error("MPI could not split this group");
-  }
-
-  return Group(std::make_shared<MPIGroupImpl>(new_comm, false));
-}
-
 bool is_available() {
   return mpi().is_available();
 }
 
-Group init(bool strict /* = false */) {
-  static std::shared_ptr<MPIGroupImpl> global_group = nullptr;
-
-  if (global_group == nullptr) {
+std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
   if (!mpi().init_safe()) {
     if (strict) {
       throw std::runtime_error("Cannot initialize MPI");
     }
-    global_group = std::make_shared<MPIGroupImpl>();
-  } else {
-    global_group = std::make_shared<MPIGroupImpl>(mpi().world(), true);
-  }
+    return nullptr;
   }
 
-  // Ensure the communication stream is alive before
-  // the graph is evaluated
-  detail::communication_stream();
-  return Group(global_group);
+  return std::make_shared<MPIGroup>(mpi().world(), true);
 }
 
-namespace detail {
-
-Stream communication_stream() {
-  static Stream comm_stream = new_stream(Device::cpu);
-  return comm_stream;
-}
-
-void all_sum(Group group, const array& input_, array& output) {
-  array input = ensure_row_contiguous(input_);
-  mpi().all_reduce(
-      (input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
-                                                  : input.data<void>(),
-      output.data<void>(),
-      input.size(),
-      mpi().datatype(input),
-      mpi().op_sum(input),
-      to_comm(group));
-}
-
-void all_gather(Group group, const array& input_, array& output) {
-  array input = ensure_row_contiguous(input_);
-  mpi().all_gather(
-      input.data<void>(),
-      input.size(),
-      mpi().datatype(input),
-      output.data<void>(),
-      input.size(),
-      mpi().datatype(output),
-      to_comm(group));
-}
-
-void send(Group group, const array& input_, int dst) {
-  array input = ensure_row_contiguous(input_);
-  mpi().send(
-      input.data<void>(),
-      input.size(),
-      mpi().datatype(input),
-      dst,
-      0,
-      to_comm(group));
-}
-
-void recv(Group group, array& out, int src) {
-  MPI_Status status;
-  mpi().recv(
-      out.data<void>(),
-      out.size(),
-      mpi().datatype(out),
-      src,
-      MPI_ANY_TAG,
-      to_comm(group),
-      &status);
-}
-
-} // namespace detail
-
-} // namespace mlx::core::distributed
+} // namespace mlx::core::distributed::mpi
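Note that MPIGroup::send tags messages with 0 while recv matches MPI_ANY_TAG, and both sides rely on ensure_row_contiguous and materialized buffers. A hedged two-rank ping through the detail layer, assuming an mpirun -np 2 launch of a binary linked against this build; the buffer shape and values are invented, and in normal use these calls are driven by the scheduler on the communication stream rather than invoked directly:

// Hypothetical two-rank ping; not part of the commit.
#include <cstdio>

#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/ops.h"

using namespace mlx::core;

int main() {
  auto world = distributed::init(/* strict = */ true);  // demand real MPI

  array buf = full({4}, static_cast<float>(world.rank()));
  buf.eval();  // detail::send/recv expect materialized buffers

  if (world.rank() == 0) {
    distributed::detail::send(world, buf, /* dst = */ 1);
  } else {
    distributed::detail::recv(world, buf, /* src = */ 0);
    std::printf("rank 1 received %.1f\n", buf.data<float>()[0]);  // 0.0
  }
  return 0;
}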
mlx/distributed/mpi/mpi.h (new file, 12 lines)
@@ -0,0 +1,12 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/distributed/distributed.h"
+
+namespace mlx::core::distributed::mpi {
+
+using GroupImpl = mlx::core::distributed::detail::GroupImpl;
+
+bool is_available();
+std::shared_ptr<GroupImpl> init(bool strict = false);
+
+} // namespace mlx::core::distributed::mpi
mlx/distributed/mpi/no_mpi.cpp (new file, 20 lines)
@@ -0,0 +1,20 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/distributed/mpi/mpi.h"
+
+namespace mlx::core::distributed::mpi {
+
+using GroupImpl = mlx::core::distributed::detail::GroupImpl;
+
+bool is_available() {
+  return false;
+}
+
+std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
+  if (strict) {
+    throw std::runtime_error("Cannot initialize MPI");
+  }
+  return nullptr;
+}
+
+} // namespace mlx::core::distributed::mpi
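Together with mpi.cpp, this stub gives distributed::init a single contract to program against: a backend either returns a live GroupImpl, returns nullptr, or throws only when the caller insisted via strict. A small sketch of the outcomes (caller code invented for illustration):

// Illustration of the init contract; not part of the commit.
#include <memory>
#include <stdexcept>

#include "mlx/distributed/mpi/mpi.h"

int main() {
  namespace mpi = mlx::core::distributed::mpi;

  // Non-strict: nullptr when this no_mpi.cpp stub was compiled in (or when
  // MPI_Init fails at runtime); distributed::init then picks EmptyGroup.
  std::shared_ptr<mpi::GroupImpl> g = mpi::init(/* strict = */ false);

  // Strict: the same failure surfaces as an exception instead.
  try {
    auto h = mpi::init(/* strict = */ true);
  } catch (const std::runtime_error&) {
    // "Cannot initialize MPI"
  }
  return 0;
}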
mlx/distributed/no_distributed.cpp (deleted, 42 lines)
@@ -1,42 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-#include "mlx/distributed/distributed.h"
-#include "mlx/distributed/distributed_impl.h"
-
-namespace mlx::core::distributed {
-
-int Group::rank() {
-  return 0;
-}
-
-int Group::size() {
-  return 1;
-}
-
-Group Group::split(int color, int key) {
-  throw std::runtime_error("Cannot split the distributed group further");
-}
-
-bool is_available() {
-  return false;
-}
-
-Group init(bool strict /* = false */) {
-  return Group(nullptr);
-}
-
-namespace detail {
-
-Stream communication_stream() {
-  static Stream comm_stream = new_stream(Device::cpu);
-  return comm_stream;
-}
-
-void all_sum(Group group, const array& input, array& output) {}
-void all_gather(Group group, const array& input, array& output) {}
-void send(Group group, const array& input, int dst) {}
-void recv(Group group, array& out, int src) {}
-
-} // namespace detail
-
-} // namespace mlx::core::distributed