From 3de8ce3f3c81612962835b00cfa7e461a12b35b2 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 3 Jun 2024 16:47:47 -0700 Subject: [PATCH] In place all-reduce and forgiving init (#1178) --- mlx/distributed/distributed.h | 6 +++++- mlx/distributed/mpi/mpi.cpp | 14 ++++++++++---- mlx/distributed/no_distributed.cpp | 2 +- mlx/distributed/primitives.cpp | 6 +++++- python/src/distributed.cpp | 10 ++++++++++ 5 files changed, 31 insertions(+), 7 deletions(-) diff --git a/mlx/distributed/distributed.h b/mlx/distributed/distributed.h index cad75b396..e2bd0771c 100644 --- a/mlx/distributed/distributed.h +++ b/mlx/distributed/distributed.h @@ -43,8 +43,12 @@ struct Group { /** * Initialize the distributed backend and return the group containing all * discoverable processes. + * + * If strict is true then throw an error if we couldn't initialize the + * distributed subsystem. Otherwise simply return a singleton group which will + * render communication operations as no-op. */ -Group init(); +Group init(bool strict = false); namespace detail { diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index 3d1818195..8e9b6caa3 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -168,6 +168,7 @@ MPIWrapper& mpi() { } struct MPIGroupImpl { + MPIGroupImpl() : comm_(nullptr), global_(true), rank_(0), size_(1) {} MPIGroupImpl(MPI_Comm comm, bool global) : comm_(comm), global_(global), rank_(-1), size_(-1) {} ~MPIGroupImpl() { @@ -235,14 +236,18 @@ bool is_available() { return mpi().is_available(); } -Group init() { +Group init(bool strict /* = false */) { static std::shared_ptr global_group = nullptr; if (global_group == nullptr) { if (!mpi().init_safe()) { - throw std::runtime_error("Cannot initialize MPI"); + if (strict) { + throw std::runtime_error("Cannot initialize MPI"); + } + global_group = std::make_shared(); + } else { + global_group = std::make_shared(mpi().world(), true); } - global_group = 
std::make_shared<MPIGroupImpl>(mpi().world(), true); } return Group(global_group);
)pbdoc"); m.def(