In place all-reduce and forgiving init (#1178)

This commit is contained in:
Angelos Katharopoulos 2024-06-03 16:47:47 -07:00 committed by GitHub
parent 4d485fca24
commit 3de8ce3f3c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 31 additions and 7 deletions

View File

@ -43,8 +43,12 @@ struct Group {
/** /**
* Initialize the distributed backend and return the group containing all * Initialize the distributed backend and return the group containing all
* discoverable processes. * discoverable processes.
*
* If strict is true, throw an error when the distributed subsystem cannot be
* initialized. Otherwise, simply return a singleton group on which
* communication operations are no-ops.
*/ */
Group init(); Group init(bool strict = false);
namespace detail { namespace detail {

View File

@ -168,6 +168,7 @@ MPIWrapper& mpi() {
} }
struct MPIGroupImpl { struct MPIGroupImpl {
MPIGroupImpl() : comm_(nullptr), global_(true), rank_(0), size_(1) {}
MPIGroupImpl(MPI_Comm comm, bool global) MPIGroupImpl(MPI_Comm comm, bool global)
: comm_(comm), global_(global), rank_(-1), size_(-1) {} : comm_(comm), global_(global), rank_(-1), size_(-1) {}
~MPIGroupImpl() { ~MPIGroupImpl() {
@ -235,14 +236,18 @@ bool is_available() {
return mpi().is_available(); return mpi().is_available();
} }
Group init() { Group init(bool strict /* = false */) {
static std::shared_ptr<MPIGroupImpl> global_group = nullptr; static std::shared_ptr<MPIGroupImpl> global_group = nullptr;
if (global_group == nullptr) { if (global_group == nullptr) {
if (!mpi().init_safe()) { if (!mpi().init_safe()) {
throw std::runtime_error("Cannot initialize MPI"); if (strict) {
throw std::runtime_error("Cannot initialize MPI");
}
global_group = std::make_shared<MPIGroupImpl>();
} else {
global_group = std::make_shared<MPIGroupImpl>(mpi().world(), true);
} }
global_group = std::make_shared<MPIGroupImpl>(mpi().world(), true);
} }
return Group(global_group); return Group(global_group);
@ -258,7 +263,8 @@ Stream communication_stream() {
void all_reduce_sum(Group group, const array& input_, array& output) { void all_reduce_sum(Group group, const array& input_, array& output) {
array input = ensure_row_contiguous(input_); array input = ensure_row_contiguous(input_);
mpi().all_reduce( mpi().all_reduce(
input.data<void>(), (input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
: input.data<void>(),
output.data<void>(), output.data<void>(),
input.size(), input.size(),
mpi().datatype(input), mpi().datatype(input),

View File

@ -20,7 +20,7 @@ bool is_available() {
return false; return false;
} }
Group init() { Group init(bool strict /* = false */) {
return Group(nullptr); return Group(nullptr);
} }

View File

@ -16,7 +16,11 @@ void AllReduce::eval_cpu(
assert(inputs.size() == 1); assert(inputs.size() == 1);
assert(outputs.size() == 1); assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes())); if (inputs[0].is_donatable()) {
outputs[0].copy_shared_buffer(inputs[0]);
} else {
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
}
switch (reduce_type_) { switch (reduce_type_) {
case Sum: case Sum:

View File

@ -54,8 +54,18 @@ void init_distributed(nb::module_& parent_module) {
m.def( m.def(
"init", "init",
&distributed::init, &distributed::init,
"strict"_a = false,
nb::sig("def init(strict: bool = False) -> Group"),
R"pbdoc( R"pbdoc(
Initialize the communication backend and create the global communication group. Initialize the communication backend and create the global communication group.
Args:
strict (bool, optional): If ``False``, return a singleton group when
``mx.distributed.is_available()`` returns ``False``; if ``True``,
throw a runtime error in that case. Default: ``False``
Returns:
Group: The group representing all the launched processes.
)pbdoc"); )pbdoc");
m.def( m.def(