From 515f1049266fb3c9ed1ee469820885f61e75ced1 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Thu, 10 Apr 2025 08:22:20 +0200 Subject: [PATCH] Min / max reductions (#2041) --- mlx/backend/cpu/distributed.cpp | 8 ++- mlx/distributed/distributed.cpp | 18 +++++++ mlx/distributed/distributed_impl.h | 8 +++ mlx/distributed/mpi/mpi.cpp | 42 +++++++++++++++ mlx/distributed/ops.cpp | 34 ++++++++++++ mlx/distributed/ops.h | 10 ++++ mlx/distributed/primitives.cpp | 15 +++++- mlx/distributed/ring/ring.cpp | 77 ++++++++++++++++++++------- python/src/distributed.cpp | 57 ++++++++++++++++++++ python/tests/mpi_test_distributed.py | 16 ++++++ python/tests/ring_test_distributed.py | 18 ++++--- 11 files changed, 276 insertions(+), 27 deletions(-) diff --git a/mlx/backend/cpu/distributed.cpp b/mlx/backend/cpu/distributed.cpp index 1afa027a8..dd4d179ac 100644 --- a/mlx/backend/cpu/distributed.cpp +++ b/mlx/backend/cpu/distributed.cpp @@ -46,8 +46,14 @@ void AllReduce::eval_cpu( case Sum: distributed::detail::all_sum(group(), in, outputs[0], stream()); break; + case Max: + distributed::detail::all_max(group(), in, outputs[0], stream()); + break; + case Min: + distributed::detail::all_min(group(), in, outputs[0], stream()); + break; default: - throw std::runtime_error("Only all reduce sum is supported for now"); + throw std::runtime_error("Only all reduce sum, min and max are supported for now"); } } diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp index 33f1d1320..cc01e6090 100644 --- a/mlx/distributed/distributed.cpp +++ b/mlx/distributed/distributed.cpp @@ -15,6 +15,14 @@ void all_sum(Group group, const array& input, array& output, Stream stream) { group.raw_group()->all_sum(input, output, stream); } +void all_max(Group group, const array& input, array& output, Stream stream) { + group.raw_group()->all_max(input, output, stream); +} + +void all_min(Group group, const array& input, array& output, Stream stream) { + group.raw_group()->all_min(input, output, stream); +} + void all_gather(Group group, const array& input, array& output, Stream stream) { group.raw_group()->all_gather(input, output, stream); } @@ -57,6 +65,16 @@ class EmptyGroup : public GroupImpl { throw std::runtime_error( "Communication not implemented in an empty distributed group."); } + + void all_max(const array&, array&, Stream) override { + throw std::runtime_error( + "Communication not implemented in an empty distributed group."); + } + + void all_min(const array&, array&, Stream) override { + throw std::runtime_error( + "Communication not implemented in an empty distributed group."); + } }; } // namespace detail diff --git a/mlx/distributed/distributed_impl.h b/mlx/distributed/distributed_impl.h index 7c06068b2..8b0327131 100644 --- a/mlx/distributed/distributed_impl.h +++ b/mlx/distributed/distributed_impl.h @@ -21,6 +21,8 @@ class GroupImpl { virtual void all_gather(const array& input, array& output, Stream stream) = 0; virtual void send(const array& input, int dst, Stream stream) = 0; virtual void recv(array& out, int src, Stream stream) = 0; + virtual void all_max(const array& input, array& output, Stream stream) = 0; + virtual void all_min(const array& input, array& output, Stream stream) = 0; }; /* Perform an all reduce sum operation */ @@ -35,4 +37,10 @@ void send(Group group, const array& input, int dst, Stream stream); /** Recv an array from the src rank */ void recv(Group group, array& out, int src, Stream stream); +/** Max reduction */ +void all_max(Group group, const array& input, array& output, Stream stream); + +/** Min reduction */ +void all_min(Group group, const array& input, array& output, Stream stream); + } // namespace mlx::core::distributed::detail diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index ed0d66b3f..f48009397 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -93,6 +93,8 @@ struct MPIWrapper { // Ops LOAD_SYMBOL(ompi_mpi_op_sum, op_sum_); + LOAD_SYMBOL(ompi_mpi_op_max, op_max_); + LOAD_SYMBOL(ompi_mpi_op_min, op_min_); // Datatypes LOAD_SYMBOL(ompi_mpi_c_bool, mpi_bool_); @@ -191,6 +193,14 @@ struct MPIWrapper { } } + MPI_Op op_max(const array& arr) { + return op_max_; + } + + MPI_Op op_min(const array& arr) { + return op_min_; + } + void* libmpi_handle_; // API @@ -219,6 +229,8 @@ struct MPIWrapper { MPI_Op op_sum_; MPI_Op op_sum_f16_; MPI_Op op_sum_bf16_; + MPI_Op op_max_; + MPI_Op op_min_; // Datatypes MPI_Datatype mpi_bool_; @@ -306,6 +318,36 @@ class MPIGroup : public GroupImpl { comm_); } + void all_max(const array& input, array& output, Stream stream) override { + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(input); + encoder.set_output_array(output); + encoder.dispatch( + mpi().all_reduce, + (input.data() == output.data()) ? MPI_IN_PLACE + : input.data(), + output.data(), + input.size(), + mpi().datatype(input), + mpi().op_max(input), + comm_); + } + + void all_min(const array& input, array& output, Stream stream) override { + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(input); + encoder.set_output_array(output); + encoder.dispatch( + mpi().all_reduce, + (input.data() == output.data()) ? MPI_IN_PLACE + : input.data(), + output.data(), + input.size(), + mpi().datatype(input), + mpi().op_min(input), + comm_); + } + void all_gather(const array& input, array& output, Stream stream) override { auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(input); diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp index 865911ac6..0a5114805 100644 --- a/mlx/distributed/ops.cpp +++ b/mlx/distributed/ops.cpp @@ -36,6 +36,40 @@ array all_sum( {x}); } +array all_max( + const array& x, + std::optional group_ /* = std::nullopt */, + StreamOrDevice s /* = {} */) { + auto group = to_group(group_); + + if (group.size() == 1) { + return x; + } + return array( + x.shape(), + x.dtype(), + std::make_shared( + to_stream(s, Device::cpu), group, AllReduce::Max), + {x}); +} + +array all_min( + const array& x, + std::optional group_ /* = std::nullopt */, + StreamOrDevice s /* = {} */) { + auto group = to_group(group_); + + if (group.size() == 1) { + return x; + } + return array( + x.shape(), + x.dtype(), + std::make_shared( + to_stream(s, Device::cpu), group, AllReduce::Min), + {x}); +} + array all_gather( const array& x, std::optional group_ /* = std::nullopt */, diff --git a/mlx/distributed/ops.h b/mlx/distributed/ops.h index 9430106b1..edd1fc9f4 100644 --- a/mlx/distributed/ops.h +++ b/mlx/distributed/ops.h @@ -38,4 +38,14 @@ array recv_like( std::optional group = std::nullopt, StreamOrDevice s = {}); +array all_max( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +array all_min( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + } // namespace mlx::core::distributed diff --git a/mlx/distributed/primitives.cpp b/mlx/distributed/primitives.cpp index de28788b3..576424cdd 100644 --- a/mlx/distributed/primitives.cpp +++ b/mlx/distributed/primitives.cpp @@ -15,8 +15,14 @@ std::pair, std::vector> AllReduce::vmap( switch (reduce_type_) { case Sum: return {{all_sum(inputs[0], group(), stream())}, axes}; + case Max: + return {{all_max(inputs[0], group(), stream())}, axes}; + case Min: + return {{all_min(inputs[0], group(), stream())}, axes}; default: - throw std::runtime_error("Only all reduce sum is supported for now"); + + throw std::runtime_error( + "Only all reduce sum, max and min are supported for now"); } } @@ -27,8 +33,13 @@ std::vector AllReduce::jvp( switch (reduce_type_) { case Sum: return {all_sum(tangents[0], group(), stream())}; + case Max: + return {all_max(tangents[0], group(), stream())}; + case Min: + return {all_min(tangents[0], group(), stream())}; default: - throw std::runtime_error("Only all reduce sum is supported for now"); + throw std::runtime_error( + "Only all reduce sum, max and min are supported for now"); } } diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index e55d960e7..b31274e23 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -503,15 +503,38 @@ std::vector make_connections( return sockets; } +template +struct SumOp { + void operator()(const T* input, T* output, size_t N) { + while (N-- > 0) { + *output += *input; + input++; + output++; + } + } +}; template -void sum_inplace(const T* input, T* output, size_t N) { - while (N-- > 0) { - *output += *input; - input++; - output++; +struct MaxOp { + void operator()(const T* input, T* output, size_t N) { + while (N-- > 0) { + *output = std::max(*output, *input); + input++; + output++; + } } -} +}; + +template +struct MinOp { + void operator()(const T* input, T* output, size_t N) { + while (N-- > 0) { + *output = std::min(*output, *input); + input++; + output++; + } + } +}; } // namespace @@ -605,7 +628,18 @@ class RingGroup : public GroupImpl { } void all_sum(const array& input, array& output, Stream stream) override { - SWITCH_TYPE(output, all_sum(input, output, stream)); + SWITCH_TYPE( + output, all_reduce>(input, output, stream, SumOp())); + } + + void all_max(const array& input, array& output, Stream stream) override { + SWITCH_TYPE( + output, all_reduce>(input, output, stream, MaxOp())); + } + + void all_min(const array& input, array& output, Stream stream) override { + SWITCH_TYPE( + output, all_reduce>(input, output, stream, MinOp())); } std::shared_ptr split(int color, int key = -1) override { @@ -694,13 +728,17 @@ class RingGroup : public GroupImpl { } private: - template - void all_sum(const array& input, array& output, Stream stream) { + template + void all_reduce( + const array& input, + array& output, + Stream stream, + ReduceOp reduce_op) { auto in_ptr = input.data(); auto out_ptr = output.data(); auto& encoder = cpu::get_command_encoder(stream); encoder.set_output_array(output); - encoder.dispatch([in_ptr, out_ptr, size = input.size(), this]() { + encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() { // If the input data cannot be split into size_ segments then copy it and // all reduce a local buffer prefilled with 0s. size_t nbytes = size * sizeof(T); @@ -717,13 +755,14 @@ class RingGroup : public GroupImpl { char buffer[1024]; std::memset(buffer, 0, size_ * sizeof(T)); std::memcpy(buffer, in_ptr, nbytes); - all_sum_impl( + all_reduce_impl( reinterpret_cast(buffers_.data()), reinterpret_cast(buffer), size_, sockets_right_[0], sockets_left_[0], - -1); + -1, + reduce_op); std::memcpy(out_ptr, buffer, nbytes); return; } @@ -746,7 +785,7 @@ class RingGroup : public GroupImpl { for (int i = 0; i < n_reduces; i++) { all_sums.emplace_back(pool_.enqueue(std::bind( - &RingGroup::all_sum_impl, + &RingGroup::all_reduce_impl, this, reinterpret_cast( buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS), @@ -754,7 +793,8 @@ class RingGroup : public GroupImpl { std::min(size, (i + 1) * step) - i * step, sockets_right_[i / 2], sockets_left_[i / 2], - (i % 2) ? -1 : 1))); + (i % 2) ? -1 : 1, + reduce_op))); } for (auto& f : all_sums) { f.wait(); @@ -762,14 +802,15 @@ class RingGroup : public GroupImpl { }); } - template - void all_sum_impl( + template + void all_reduce_impl( T* buffer, T* data, size_t data_size, int socket_right, int socket_left, - int direction) { + int direction, + ReduceOp reduce_op) { // Choose which socket we send to and recv from int socket_send = (direction < 0) ? socket_right : socket_left; int socket_recv = (direction < 0) ? socket_left : socket_right; @@ -846,7 +887,7 @@ class RingGroup : public GroupImpl { sends[b].wait(); recvs[b].wait(); if (2 * j < send_plan.size()) { - sum_inplace( + reduce_op( recv_buffers[j % ALL_SUM_BUFFERS], data + recv_plan[j].first, recv_plan[j].second - recv_plan[j].first); diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp index ff24b8a95..c9acc8583 100644 --- a/python/src/distributed.cpp +++ b/python/src/distributed.cpp @@ -117,7 +117,64 @@ void init_distributed(nb::module_& parent_module) { Returns: array: The sum of all ``x`` arrays. )pbdoc"); + m.def( + "all_max", + [](const ScalarOrArray& x, + std::optional group, + mx::StreamOrDevice s) { + return mx::distributed::all_max(to_array(x), group, s); + }, + "x"_a, + nb::kw_only(), + "group"_a = nb::none(), + "stream"_a = nb::none(), + nb::sig( + "def all_max(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + All reduce max. + Find the maximum of the ``x`` arrays from all processes in the group. + + Args: + x (array): Input array. + group (Group): The group of processes that will participate in the + reduction. If set to ``None`` the global group is used. Default: + ``None``. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + array: The maximum of all ``x`` arrays. + )pbdoc"); + m.def( + "all_min", + [](const ScalarOrArray& x, + std::optional group, + mx::StreamOrDevice s) { + return mx::distributed::all_min(to_array(x), group, s); + }, + "x"_a, + nb::kw_only(), + "group"_a = nb::none(), + "stream"_a = nb::none(), + nb::sig( + "def all_min(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + All reduce min. + + Find the minimum of the ``x`` arrays from all processes in the group. + + Args: + x (array): Input array. + group (Group): The group of processes that will participate in the + reduction. If set to ``None`` the global group is used. Default: + ``None``. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + array: The minimum of all ``x`` arrays. + )pbdoc"); m.def( "all_gather", [](const ScalarOrArray& x, diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py index ebc8ad728..65fbd09ce 100644 --- a/python/tests/mpi_test_distributed.py +++ b/python/tests/mpi_test_distributed.py @@ -124,6 +124,22 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): x = mx.distributed.recv_like(x, neighbor, group=pairs) mx.eval(y, x) + def test_min_max(self): + world = mx.distributed.init() + base = mx.arange(16).reshape(4, 4) + x = base + world.rank() * 32 + + def _test_reduction(reduction: str = "all_max"): + + target = base + ((world.size() - 1) * 16) * (reduction == "max") + reducer = getattr(mx.distributed, reduction) + y = reducer(x) + + self.assertTrue(mx.allclose(y, target)) + + for reduction in ["all_max", "all_min"]: + _test_reduction(reduction) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/ring_test_distributed.py b/python/tests/ring_test_distributed.py index 169889559..d74c534e0 100644 --- a/python/tests/ring_test_distributed.py +++ b/python/tests/ring_test_distributed.py @@ -44,17 +44,23 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): (1024, 1024), ] key = mx.random.key(0) + reductions = ["min", "max", "sum"] + for dt, rtol in dtypes: for sh in sizes: x = ( mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10 ).astype(dt) - y = mx.distributed.all_sum(x[world.rank()]) - z = sum( - x[i] for i in range(world.size()) - ) # to ensure that we don't sum to int32 - maxrelerror = ((y - z).abs() / z.abs()).max() - self.assertLessEqual(maxrelerror, rtol) + for reduction in reductions: + reducer_distributed = getattr(mx.distributed, f"all_{reduction}") + y = reducer_distributed(x[world.rank()]) + + reducer = getattr(mx, reduction) + z = reducer(x, axis=0) + mx.eval(y, z) + + maxrelerror = ((y - z).abs() / z.abs()).max() + self.assertLessEqual(maxrelerror, rtol) def test_all_gather(self): world = mx.distributed.init()