mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
In place all-reduce and forgiving init (#1178)
This commit is contained in:
parent
4d485fca24
commit
3de8ce3f3c
@ -43,8 +43,12 @@ struct Group {
|
||||
/**
|
||||
* Initialize the distributed backend and return the group containing all
|
||||
* discoverable processes.
|
||||
*
|
||||
* If strict is true then throw an error if we couldn't initialize the
|
||||
* distributed subsystem. Otherwise simply return a singleton group which will
|
||||
* render communication operations as no-ops.
|
||||
*/
|
||||
Group init();
|
||||
Group init(bool strict = false);
|
||||
|
||||
namespace detail {
|
||||
|
||||
|
@ -168,6 +168,7 @@ MPIWrapper& mpi() {
|
||||
}
|
||||
|
||||
struct MPIGroupImpl {
|
||||
MPIGroupImpl() : comm_(nullptr), global_(true), rank_(0), size_(1) {}
|
||||
MPIGroupImpl(MPI_Comm comm, bool global)
|
||||
: comm_(comm), global_(global), rank_(-1), size_(-1) {}
|
||||
~MPIGroupImpl() {
|
||||
@ -235,15 +236,19 @@ bool is_available() {
|
||||
return mpi().is_available();
|
||||
}
|
||||
|
||||
Group init() {
|
||||
Group init(bool strict /* = false */) {
|
||||
static std::shared_ptr<MPIGroupImpl> global_group = nullptr;
|
||||
|
||||
if (global_group == nullptr) {
|
||||
if (!mpi().init_safe()) {
|
||||
if (strict) {
|
||||
throw std::runtime_error("Cannot initialize MPI");
|
||||
}
|
||||
global_group = std::make_shared<MPIGroupImpl>();
|
||||
} else {
|
||||
global_group = std::make_shared<MPIGroupImpl>(mpi().world(), true);
|
||||
}
|
||||
}
|
||||
|
||||
return Group(global_group);
|
||||
}
|
||||
@ -258,7 +263,8 @@ Stream communication_stream() {
|
||||
void all_reduce_sum(Group group, const array& input_, array& output) {
|
||||
array input = ensure_row_contiguous(input_);
|
||||
mpi().all_reduce(
|
||||
input.data<void>(),
|
||||
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
|
||||
: input.data<void>(),
|
||||
output.data<void>(),
|
||||
input.size(),
|
||||
mpi().datatype(input),
|
||||
|
@ -20,7 +20,7 @@ bool is_available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
Group init() {
|
||||
Group init(bool strict /* = false */) {
|
||||
return Group(nullptr);
|
||||
}
|
||||
|
||||
|
@ -16,7 +16,11 @@ void AllReduce::eval_cpu(
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
if (inputs[0].is_donatable()) {
|
||||
outputs[0].copy_shared_buffer(inputs[0]);
|
||||
} else {
|
||||
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
|
||||
}
|
||||
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
|
@ -54,8 +54,18 @@ void init_distributed(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"init",
|
||||
&distributed::init,
|
||||
"strict"_a = false,
|
||||
nb::sig("def init(strict: bool = False) -> Group"),
|
||||
R"pbdoc(
|
||||
Initialize the communication backend and create the global communication group.
|
||||
|
||||
Args:
|
||||
strict (bool, optional): If set to False it returns a singleton group
|
||||
in case ``mx.distributed.is_available()`` returns False otherwise
|
||||
it throws a runtime error. Default: ``False``
|
||||
|
||||
Returns:
|
||||
Group: The group representing all the launched processes.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
|
Loading…
Reference in New Issue
Block a user