mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
In place all-reduce and forgiving init (#1178)
This commit is contained in:
parent
4d485fca24
commit
3de8ce3f3c
@ -43,8 +43,12 @@ struct Group {
|
|||||||
/**
|
/**
|
||||||
* Initialize the distributed backend and return the group containing all
|
* Initialize the distributed backend and return the group containing all
|
||||||
* discoverable processes.
|
* discoverable processes.
|
||||||
|
*
|
||||||
|
* If strict is true then throw an error if we couldn't initialize the
|
||||||
|
* distributed subsystem. Otherwise simply return a singleton group which will
|
||||||
|
* render communication operations as no-op.
|
||||||
*/
|
*/
|
||||||
Group init();
|
Group init(bool strict = false);
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
|
|
||||||
|
@ -168,6 +168,7 @@ MPIWrapper& mpi() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct MPIGroupImpl {
|
struct MPIGroupImpl {
|
||||||
|
MPIGroupImpl() : comm_(nullptr), global_(true), rank_(0), size_(1) {}
|
||||||
MPIGroupImpl(MPI_Comm comm, bool global)
|
MPIGroupImpl(MPI_Comm comm, bool global)
|
||||||
: comm_(comm), global_(global), rank_(-1), size_(-1) {}
|
: comm_(comm), global_(global), rank_(-1), size_(-1) {}
|
||||||
~MPIGroupImpl() {
|
~MPIGroupImpl() {
|
||||||
@ -235,14 +236,18 @@ bool is_available() {
|
|||||||
return mpi().is_available();
|
return mpi().is_available();
|
||||||
}
|
}
|
||||||
|
|
||||||
Group init() {
|
Group init(bool strict /* = false */) {
|
||||||
static std::shared_ptr<MPIGroupImpl> global_group = nullptr;
|
static std::shared_ptr<MPIGroupImpl> global_group = nullptr;
|
||||||
|
|
||||||
if (global_group == nullptr) {
|
if (global_group == nullptr) {
|
||||||
if (!mpi().init_safe()) {
|
if (!mpi().init_safe()) {
|
||||||
throw std::runtime_error("Cannot initialize MPI");
|
if (strict) {
|
||||||
|
throw std::runtime_error("Cannot initialize MPI");
|
||||||
|
}
|
||||||
|
global_group = std::make_shared<MPIGroupImpl>();
|
||||||
|
} else {
|
||||||
|
global_group = std::make_shared<MPIGroupImpl>(mpi().world(), true);
|
||||||
}
|
}
|
||||||
global_group = std::make_shared<MPIGroupImpl>(mpi().world(), true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return Group(global_group);
|
return Group(global_group);
|
||||||
@ -258,7 +263,8 @@ Stream communication_stream() {
|
|||||||
void all_reduce_sum(Group group, const array& input_, array& output) {
|
void all_reduce_sum(Group group, const array& input_, array& output) {
|
||||||
array input = ensure_row_contiguous(input_);
|
array input = ensure_row_contiguous(input_);
|
||||||
mpi().all_reduce(
|
mpi().all_reduce(
|
||||||
input.data<void>(),
|
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
|
||||||
|
: input.data<void>(),
|
||||||
output.data<void>(),
|
output.data<void>(),
|
||||||
input.size(),
|
input.size(),
|
||||||
mpi().datatype(input),
|
mpi().datatype(input),
|
||||||
|
@ -20,7 +20,7 @@ bool is_available() {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
Group init() {
|
Group init(bool strict /* = false */) {
|
||||||
return Group(nullptr);
|
return Group(nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,7 +16,11 @@ void AllReduce::eval_cpu(
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(outputs.size() == 1);
|
assert(outputs.size() == 1);
|
||||||
|
|
||||||
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
|
if (inputs[0].is_donatable()) {
|
||||||
|
outputs[0].copy_shared_buffer(inputs[0]);
|
||||||
|
} else {
|
||||||
|
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
|
||||||
|
}
|
||||||
|
|
||||||
switch (reduce_type_) {
|
switch (reduce_type_) {
|
||||||
case Sum:
|
case Sum:
|
||||||
|
@ -54,8 +54,18 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
m.def(
|
m.def(
|
||||||
"init",
|
"init",
|
||||||
&distributed::init,
|
&distributed::init,
|
||||||
|
"strict"_a = false,
|
||||||
|
nb::sig("def init(strict: bool = False) -> Group"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Initialize the communication backend and create the global communication group.
|
Initialize the communication backend and create the global communication group.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strict (bool, optional): If set to False it returns a singleton group
|
||||||
|
in case ``mx.distributed.is_available()`` returns False otherwise
|
||||||
|
it throws a runtime error. Default: ``False``
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Group: The group representing all the launched processes.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
|
Loading…
Reference in New Issue
Block a user