Refactor distributed backend (#1752)
mirror of https://github.com/ml-explore/mlx.git

commit 545f84d905 (parent d5ec172c95)
mlx/distributed/CMakeLists.txt
@@ -1,8 +1,7 @@
-target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
-                           ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp)
+target_sources(
+  mlx
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
 
-if(MPI_FOUND AND MLX_BUILD_CPU)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
-else()
-  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
-endif()
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
mlx/distributed/distributed.cpp (new file, 99 lines)
@@ -0,0 +1,99 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/distributed/distributed.h"
+#include "mlx/distributed/distributed_impl.h"
+#include "mlx/distributed/mpi/mpi.h"
+#include "mlx/scheduler.h"
+
+namespace mlx::core::distributed {
+
+namespace detail {
+
+Stream communication_stream() {
+  static Stream comm_stream = new_stream(Device::cpu);
+  return comm_stream;
+}
+
+void all_sum(Group group, const array& input, array& output) {
+  group.raw_group()->all_sum(input, output);
+}
+
+void all_gather(Group group, const array& input, array& output) {
+  group.raw_group()->all_gather(input, output);
+}
+
+void send(Group group, const array& input, int dst) {
+  group.raw_group()->send(input, dst);
+}
+
+void recv(Group group, array& out, int src) {
+  group.raw_group()->recv(out, src);
+}
+
+class EmptyGroup : public GroupImpl {
+ public:
+  int rank() override {
+    return 0;
+  }
+
+  int size() override {
+    return 1;
+  }
+
+  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
+    throw std::runtime_error("Cannot split the distributed group further.");
+  }
+
+  void all_sum(const array& input, array& output) override {
+    throw std::runtime_error(
+        "Communication not implemented in an empty distributed group.");
+  }
+  void all_gather(const array& input, array& output) override {
+    throw std::runtime_error(
+        "Communication not implemented in an empty distributed group.");
+  }
+  void send(const array& input, int dst) override {
+    throw std::runtime_error(
+        "Communication not implemented in an empty distributed group.");
+  }
+  void recv(array& out, int src) override {
+    throw std::runtime_error(
+        "Communication not implemented in an empty distributed group.");
+  }
+};
+
+} // namespace detail
+
+bool is_available() {
+  return mpi::is_available();
+}
+
+int Group::rank() const {
+  return group_->rank();
+}
+
+int Group::size() const {
+  return group_->size();
+}
+
+Group Group::split(int color, int key /* = -1 */) const {
+  return Group(group_->split(color, key));
+}
+
+Group init(bool strict /* = false */) {
+  auto init_group = [strict]() {
+    auto default_group = mpi::init(strict);
+    if (default_group == nullptr) {
+      default_group = std::make_shared<detail::EmptyGroup>();
+    }
+    return default_group;
+  };
+  static std::shared_ptr<detail::GroupImpl> default_group = init_group();
+
+  // Ensure the communication stream is alive before
+  // the graph is evaluated
+  detail::communication_stream();
+  return Group(default_group);
+}
+
+} // namespace mlx::core::distributed
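This new file is the whole backend-agnostic front end: Group forwards every call to a GroupImpl, and init() falls back to EmptyGroup when mpi::init returns nullptr. A minimal caller sketch, not part of this commit (the surrounding program and the mpirun launch are assumptions), showing how the pieces above compose:

// Sketch: exercising the refactored front end. Assumes an MPI-enabled
// build launched as e.g. `mpirun -n 4 ./a.out`; without MPI it still
// runs, as a single process.
#include "mlx/distributed/distributed.h"

using namespace mlx::core;

int main() {
  // Wraps MPI_COMM_WORLD when MPI initializes; otherwise, since strict is
  // false, this returns the EmptyGroup defined above (rank 0, size 1).
  distributed::Group world = distributed::init();

  if (world.size() > 1) {
    // Sub-groups by rank parity; ordering inside each sub-group keeps the
    // original rank order because key defaults to -1.
    distributed::Group sub = world.split(world.rank() % 2);
  }
  return 0;
}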
mlx/distributed/distributed.h
@@ -8,6 +8,11 @@
 
 namespace mlx::core::distributed {
 
+// Forward declaration of the base group implementation.
+namespace detail {
+class GroupImpl;
+};
+
 /* Check if a communication backend is available */
 bool is_available();
 
@@ -17,10 +22,10 @@ bool is_available();
  * order to define more granular communication.
  */
 struct Group {
-  Group(std::shared_ptr<void> group) : group_(group) {}
+  Group(std::shared_ptr<detail::GroupImpl> group) : group_(std::move(group)) {}
 
-  int rank();
-  int size();
+  int rank() const;
+  int size() const;
 
   /**
    * Split the group according to the provided color. Namely processes that use
@@ -30,14 +35,14 @@ struct Group {
    * the key the smaller the rank. If the provided key is negative, then the
    * rank in the current group is used.
    */
-  Group split(int color, int key = -1);
+  Group split(int color, int key = -1) const;
 
-  const std::shared_ptr<void>& raw_group() {
+  const std::shared_ptr<detail::GroupImpl>& raw_group() const {
     return group_;
   }
 
  private:
-  std::shared_ptr<void> group_{nullptr};
+  std::shared_ptr<detail::GroupImpl> group_{nullptr};
 };
 
 /**
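To make the color/key contract in the split docstring concrete, a small hypothetical example (the eight-process world is an assumption, not part of the commit):

// Assume init() produced a world group of 8 processes, ranks 0..7.
auto world = mlx::core::distributed::init();

// color = rank / 4 puts ranks 0-3 in one sub-group, ranks 4-7 in another.
auto half = world.split(world.rank() / 4);

// With the default key = -1 each process is keyed by its current rank, so
// relative order is preserved: world rank 5 becomes rank 1 of the second
// sub-group. A non-negative explicit key reorders ranks instead, e.g.
// key = world.size() - world.rank() would number them in reverse.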
mlx/distributed/distributed_impl.h
@@ -6,6 +6,21 @@
 
 namespace mlx::core::distributed::detail {
 
+/**
+ * Abstract base class of a distributed group implementation.
+ */
+class GroupImpl {
+ public:
+  virtual int rank() = 0;
+  virtual int size() = 0;
+  virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
+
+  virtual void all_sum(const array& input, array& output) = 0;
+  virtual void all_gather(const array& input, array& output) = 0;
+  virtual void send(const array& input, int dst) = 0;
+  virtual void recv(array& out, int src) = 0;
+};
+
 /* Return the communication stream. */
 Stream communication_stream();
 
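GroupImpl is now the single extension point, so a new backend only has to subclass it. A minimal single-process "loopback" backend, purely as an illustration of the interface (hypothetical code, not part of the commit; it assumes contiguous arrays and ignores the layout handling a real backend needs, cf. ensure_row_contiguous in the MPI backend):

#include <cstring>
#include <stdexcept>

#include "mlx/distributed/distributed_impl.h"

namespace mlx::core::distributed::detail {

// Illustrative backend: one process, every collective is a local copy.
class LoopbackGroup : public GroupImpl {
 public:
  int rank() override { return 0; }
  int size() override { return 1; }

  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
    throw std::runtime_error("Cannot split a single-process group.");
  }

  void all_sum(const array& input, array& output) override {
    // With one rank, the sum of one contribution is the contribution.
    std::memcpy(output.data<void>(), input.data<void>(), input.nbytes());
  }
  void all_gather(const array& input, array& output) override {
    // Gathering across one rank is likewise a copy.
    std::memcpy(output.data<void>(), input.data<void>(), input.nbytes());
  }
  void send(const array&, int) override {
    throw std::runtime_error("No peers to send to in a single-process group.");
  }
  void recv(array&, int) override {
    throw std::runtime_error("No peers to receive from.");
  }
};

} // namespace mlx::core::distributed::detail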
mlx/distributed/mpi/CMakeLists.txt
@@ -1 +1,5 @@
-target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
+if(MPI_FOUND AND MLX_BUILD_CPU)
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
+else()
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp)
+endif()
mlx/distributed/mpi/mpi.cpp
@@ -6,6 +6,7 @@
 #include "mlx/backend/common/copy.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
+#include "mlx/distributed/mpi/mpi.h"
 #include "mlx/scheduler.h"
 
 #define LOAD_SYMBOL(symbol, variable) \
@@ -18,7 +19,9 @@
   } \
   }
 
-namespace mlx::core::distributed {
+namespace mlx::core::distributed::mpi {
+
+using GroupImpl = mlx::core::distributed::detail::GroupImpl;
 
 namespace {
 
@@ -233,11 +236,14 @@ MPIWrapper& mpi() {
   return wrapper;
 }
 
-struct MPIGroupImpl {
-  MPIGroupImpl() : comm_(nullptr), global_(true), rank_(0), size_(1) {}
-  MPIGroupImpl(MPI_Comm comm, bool global)
+} // namespace
+
+class MPIGroup : public GroupImpl {
+ public:
+  MPIGroup(MPI_Comm comm, bool global)
       : comm_(comm), global_(global), rank_(-1), size_(-1) {}
-  ~MPIGroupImpl() {
+
+  virtual ~MPIGroup() {
     if (global_) {
       mpi().finalize_safe();
     } else {
@@ -245,24 +251,74 @@ struct MPIGroupImpl {
     }
   }
 
-  MPI_Comm comm() {
-    return comm_;
-  }
-
-  int rank() {
+  int rank() override {
     if (rank_ < 0) {
       mpi().rank(comm_, &rank_);
     }
     return rank_;
   }
 
-  int size() {
+  int size() override {
     if (size_ < 0) {
       mpi().size(comm_, &size_);
     }
     return size_;
   }
 
+  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
+    key = (key < 0) ? rank() : key;
+
+    MPI_Comm new_comm;
+    int result = mpi().comm_split(comm_, color, key, &new_comm);
+    if (result != MPI_SUCCESS) {
+      throw std::runtime_error("MPI could not split this group");
+    }
+
+    return std::make_shared<MPIGroup>(new_comm, false);
+  }
+
+  void all_sum(const array& input_, array& output) override {
+    array input = ensure_row_contiguous(input_);
+    mpi().all_reduce(
+        (input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
+                                                    : input.data<void>(),
+        output.data<void>(),
+        input.size(),
+        mpi().datatype(input),
+        mpi().op_sum(input),
+        comm_);
+  }
+
+  void all_gather(const array& input_, array& output) override {
+    array input = ensure_row_contiguous(input_);
+    mpi().all_gather(
+        input.data<void>(),
+        input.size(),
+        mpi().datatype(input),
+        output.data<void>(),
+        input.size(),
+        mpi().datatype(output),
+        comm_);
+  }
+
+  void send(const array& input_, int dst) override {
+    array input = ensure_row_contiguous(input_);
+    mpi().send(
+        input.data<void>(), input.size(), mpi().datatype(input), dst, 0, comm_);
+  }
+
+  void recv(array& out, int src) override {
+    MPI_Status status;
+    mpi().recv(
+        out.data<void>(),
+        out.size(),
+        mpi().datatype(out),
+        src,
+        MPI_ANY_TAG,
+        comm_,
+        &status);
+  }
+
+ private:
   MPI_Comm comm_;
   bool global_;
@@ -270,112 +326,19 @@ struct MPIGroupImpl {
   int size_;
 };
 
-MPI_Comm to_comm(Group& group) {
-  return std::static_pointer_cast<MPIGroupImpl>(group.raw_group())->comm();
-}
-
-} // namespace
-
-int Group::rank() {
-  return std::static_pointer_cast<MPIGroupImpl>(group_)->rank();
-}
-
-int Group::size() {
-  return std::static_pointer_cast<MPIGroupImpl>(group_)->size();
-}
-
-Group Group::split(int color, int key) {
-  auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_);
-
-  key = (key < 0) ? rank() : key;
-
-  MPI_Comm new_comm;
-  int result = mpi().comm_split(mpi_group->comm(), color, key, &new_comm);
-  if (result != MPI_SUCCESS) {
-    throw std::runtime_error("MPI could not split this group");
-  }
-
-  return Group(std::make_shared<MPIGroupImpl>(new_comm, false));
-}
-
 bool is_available() {
   return mpi().is_available();
 }
 
-Group init(bool strict /* = false */) {
-  static std::shared_ptr<MPIGroupImpl> global_group = nullptr;
-
-  if (global_group == nullptr) {
-    if (!mpi().init_safe()) {
-      if (strict) {
-        throw std::runtime_error("Cannot initialize MPI");
-      }
-      global_group = std::make_shared<MPIGroupImpl>();
-    } else {
-      global_group = std::make_shared<MPIGroupImpl>(mpi().world(), true);
-    }
-  }
-
-  // Ensure the communication stream is alive before
-  // the graph is evaluated
-  detail::communication_stream();
-  return Group(global_group);
-}
-
-namespace detail {
-
-Stream communication_stream() {
-  static Stream comm_stream = new_stream(Device::cpu);
-  return comm_stream;
-}
-
-void all_sum(Group group, const array& input_, array& output) {
-  array input = ensure_row_contiguous(input_);
-  mpi().all_reduce(
-      (input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
-                                                  : input.data<void>(),
-      output.data<void>(),
-      input.size(),
-      mpi().datatype(input),
-      mpi().op_sum(input),
-      to_comm(group));
-}
-
-void all_gather(Group group, const array& input_, array& output) {
-  array input = ensure_row_contiguous(input_);
-  mpi().all_gather(
-      input.data<void>(),
-      input.size(),
-      mpi().datatype(input),
-      output.data<void>(),
-      input.size(),
-      mpi().datatype(output),
-      to_comm(group));
-}
-
-void send(Group group, const array& input_, int dst) {
-  array input = ensure_row_contiguous(input_);
-  mpi().send(
-      input.data<void>(),
-      input.size(),
-      mpi().datatype(input),
-      dst,
-      0,
-      to_comm(group));
-}
-
-void recv(Group group, array& out, int src) {
-  MPI_Status status;
-  mpi().recv(
-      out.data<void>(),
-      out.size(),
-      mpi().datatype(out),
-      src,
-      MPI_ANY_TAG,
-      to_comm(group),
-      &status);
-}
-
-} // namespace detail
-
-} // namespace mlx::core::distributed
+std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
+  if (!mpi().init_safe()) {
+    if (strict) {
+      throw std::runtime_error("Cannot initialize MPI");
+    }
+    return nullptr;
+  }
+
+  return std::make_shared<MPIGroup>(mpi().world(), true);
+}
+
+} // namespace mlx::core::distributed::mpi
mlx/distributed/mpi/mpi.h (new file, 12 lines)
@@ -0,0 +1,12 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/distributed/distributed.h"
+
+namespace mlx::core::distributed::mpi {
+
+using GroupImpl = mlx::core::distributed::detail::GroupImpl;
+
+bool is_available();
+std::shared_ptr<GroupImpl> init(bool strict = false);
+
+} // namespace mlx::core::distributed::mpi
mlx/distributed/mpi/no_mpi.cpp (new file, 20 lines)
@@ -0,0 +1,20 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/distributed/mpi/mpi.h"
+
+namespace mlx::core::distributed::mpi {
+
+using GroupImpl = mlx::core::distributed::detail::GroupImpl;
+
+bool is_available() {
+  return false;
+}
+
+std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
+  if (strict) {
+    throw std::runtime_error("Cannot initialize MPI");
+  }
+  return nullptr;
+}
+
+} // namespace mlx::core::distributed::mpi
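A short sketch of what the strict flag means for callers when this stub is the only backend compiled in (illustrative only; it assumes a build without MPI):

#include "mlx/distributed/distributed.h"

using namespace mlx::core;

void permissive() {
  // mpi::init(false) returns nullptr above, so the front end silently
  // substitutes the single-process EmptyGroup: rank 0, size 1.
  auto group = distributed::init(/* strict = */ false);
}

void strict_mode() {
  // In a fresh process the missing backend surfaces as an exception
  // ("Cannot initialize MPI") instead of a silent fallback. Note that
  // distributed::init caches the first group it creates in a static, so
  // whichever mode runs first in a process decides the outcome.
  auto group = distributed::init(/* strict = */ true);
}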
mlx/distributed/no_distributed.cpp (deleted, 42 lines)
@@ -1,42 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-#include "mlx/distributed/distributed.h"
-#include "mlx/distributed/distributed_impl.h"
-
-namespace mlx::core::distributed {
-
-int Group::rank() {
-  return 0;
-}
-
-int Group::size() {
-  return 1;
-}
-
-Group Group::split(int color, int key) {
-  throw std::runtime_error("Cannot split the distributed group further");
-}
-
-bool is_available() {
-  return false;
-}
-
-Group init(bool strict /* = false */) {
-  return Group(nullptr);
-}
-
-namespace detail {
-
-Stream communication_stream() {
-  static Stream comm_stream = new_stream(Device::cpu);
-  return comm_stream;
-}
-
-void all_sum(Group group, const array& input, array& output) {}
-void all_gather(Group group, const array& input, array& output) {}
-void send(Group group, const array& input, int dst) {}
-void recv(Group group, array& out, int src) {}
-
-} // namespace detail
-
-} // namespace mlx::core::distributed