Min / max reductions (#2041)

This commit is contained in:
Anastasiia Filippova 2025-04-10 08:22:20 +02:00 committed by GitHub
parent 9ecefd56db
commit 515f104926
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 276 additions and 27 deletions

View File

@ -46,8 +46,14 @@ void AllReduce::eval_cpu(
case Sum:
distributed::detail::all_sum(group(), in, outputs[0], stream());
break;
case Max:
distributed::detail::all_max(group(), in, outputs[0], stream());
break;
case Min:
distributed::detail::all_min(group(), in, outputs[0], stream());
break;
default:
throw std::runtime_error("Only all reduce sum is supported for now");
throw std::runtime_error("Only all reduce sum, min and max are supported for now");
}
}

View File

@ -15,6 +15,14 @@ void all_sum(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_sum(input, output, stream);
}
void all_max(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_max(input, output, stream);
}
void all_min(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_min(input, output, stream);
}
void all_gather(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_gather(input, output, stream);
}
@ -57,6 +65,16 @@ class EmptyGroup : public GroupImpl {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void all_max(const array&, array&, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void all_min(const array&, array&, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
};
} // namespace detail

View File

@ -21,6 +21,8 @@ class GroupImpl {
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
virtual void send(const array& input, int dst, Stream stream) = 0;
virtual void recv(array& out, int src, Stream stream) = 0;
virtual void all_max(const array& input, array& output, Stream stream) = 0;
virtual void all_min(const array& input, array& output, Stream stream) = 0;
};
/* Perform an all reduce sum operation */
@ -35,4 +37,10 @@ void send(Group group, const array& input, int dst, Stream stream);
/** Recv an array from the src rank */
void recv(Group group, array& out, int src, Stream stream);
/** Max reduction */
void all_max(Group group, const array& input, array& output, Stream stream);
/** Min reduction */
void all_min(Group group, const array& input, array& output, Stream stream);
} // namespace mlx::core::distributed::detail

View File

@ -93,6 +93,8 @@ struct MPIWrapper {
// Ops
LOAD_SYMBOL(ompi_mpi_op_sum, op_sum_);
LOAD_SYMBOL(ompi_mpi_op_max, op_max_);
LOAD_SYMBOL(ompi_mpi_op_min, op_min_);
// Datatypes
LOAD_SYMBOL(ompi_mpi_c_bool, mpi_bool_);
@ -191,6 +193,14 @@ struct MPIWrapper {
}
}
MPI_Op op_max(const array& arr) {
return op_max_;
}
MPI_Op op_min(const array& arr) {
return op_min_;
}
void* libmpi_handle_;
// API
@ -219,6 +229,8 @@ struct MPIWrapper {
MPI_Op op_sum_;
MPI_Op op_sum_f16_;
MPI_Op op_sum_bf16_;
MPI_Op op_max_;
MPI_Op op_min_;
// Datatypes
MPI_Datatype mpi_bool_;
@ -306,6 +318,36 @@ class MPIGroup : public GroupImpl {
comm_);
}
void all_max(const array& input, array& output, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.set_output_array(output);
encoder.dispatch(
mpi().all_reduce,
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
: input.data<void>(),
output.data<void>(),
input.size(),
mpi().datatype(input),
mpi().op_max(input),
comm_);
}
void all_min(const array& input, array& output, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.set_output_array(output);
encoder.dispatch(
mpi().all_reduce,
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
: input.data<void>(),
output.data<void>(),
input.size(),
mpi().datatype(input),
mpi().op_min(input),
comm_);
}
void all_gather(const array& input, array& output, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);

View File

@ -36,6 +36,40 @@ array all_sum(
{x});
}
array all_max(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
return x;
}
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Max),
{x});
}
array all_min(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
return x;
}
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Min),
{x});
}
array all_gather(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,

View File

@ -38,4 +38,14 @@ array recv_like(
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
array all_max(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
array all_min(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
} // namespace mlx::core::distributed

View File

@ -15,8 +15,14 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
switch (reduce_type_) {
case Sum:
return {{all_sum(inputs[0], group(), stream())}, axes};
case Max:
return {{all_max(inputs[0], group(), stream())}, axes};
case Min:
return {{all_min(inputs[0], group(), stream())}, axes};
default:
throw std::runtime_error("Only all reduce sum is supported for now");
throw std::runtime_error(
"Only all reduce sum, max and min are supported for now");
}
}
@ -27,8 +33,13 @@ std::vector<array> AllReduce::jvp(
switch (reduce_type_) {
case Sum:
return {all_sum(tangents[0], group(), stream())};
case Max:
return {all_max(tangents[0], group(), stream())};
case Min:
return {all_min(tangents[0], group(), stream())};
default:
throw std::runtime_error("Only all reduce sum is supported for now");
throw std::runtime_error(
"Only all reduce sum, max and min are supported for now");
}
}

View File

@ -503,15 +503,38 @@ std::vector<int> make_connections(
return sockets;
}
template <typename T>
struct SumOp {
void operator()(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output += *input;
input++;
output++;
}
}
};
template <typename T>
void sum_inplace(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output += *input;
input++;
output++;
struct MaxOp {
void operator()(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output = std::max(*output, *input);
input++;
output++;
}
}
}
};
template <typename T>
struct MinOp {
void operator()(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output = std::min(*output, *input);
input++;
output++;
}
}
};
} // namespace
@ -605,7 +628,18 @@ class RingGroup : public GroupImpl {
}
void all_sum(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(output, all_sum<T>(input, output, stream));
SWITCH_TYPE(
output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()));
}
void all_max(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()));
}
void all_min(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
}
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
@ -694,13 +728,17 @@ class RingGroup : public GroupImpl {
}
private:
template <typename T>
void all_sum(const array& input, array& output, Stream stream) {
template <typename T, typename ReduceOp>
void all_reduce(
const array& input,
array& output,
Stream stream,
ReduceOp reduce_op) {
auto in_ptr = input.data<char>();
auto out_ptr = output.data<char>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(output);
encoder.dispatch([in_ptr, out_ptr, size = input.size(), this]() {
encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() {
// If the input data cannot be split into size_ segments then copy it and
// all reduce a local buffer prefilled with 0s.
size_t nbytes = size * sizeof(T);
@ -717,13 +755,14 @@ class RingGroup : public GroupImpl {
char buffer[1024];
std::memset(buffer, 0, size_ * sizeof(T));
std::memcpy(buffer, in_ptr, nbytes);
all_sum_impl<T>(
all_reduce_impl<T, ReduceOp>(
reinterpret_cast<T*>(buffers_.data()),
reinterpret_cast<T*>(buffer),
size_,
sockets_right_[0],
sockets_left_[0],
-1);
-1,
reduce_op);
std::memcpy(out_ptr, buffer, nbytes);
return;
}
@ -746,7 +785,7 @@ class RingGroup : public GroupImpl {
for (int i = 0; i < n_reduces; i++) {
all_sums.emplace_back(pool_.enqueue(std::bind(
&RingGroup::all_sum_impl<T>,
&RingGroup::all_reduce_impl<T, ReduceOp>,
this,
reinterpret_cast<T*>(
buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS),
@ -754,7 +793,8 @@ class RingGroup : public GroupImpl {
std::min(size, (i + 1) * step) - i * step,
sockets_right_[i / 2],
sockets_left_[i / 2],
(i % 2) ? -1 : 1)));
(i % 2) ? -1 : 1,
reduce_op)));
}
for (auto& f : all_sums) {
f.wait();
@ -762,14 +802,15 @@ class RingGroup : public GroupImpl {
});
}
template <typename T>
void all_sum_impl(
template <typename T, typename ReduceOp>
void all_reduce_impl(
T* buffer,
T* data,
size_t data_size,
int socket_right,
int socket_left,
int direction) {
int direction,
ReduceOp reduce_op) {
// Choose which socket we send to and recv from
int socket_send = (direction < 0) ? socket_right : socket_left;
int socket_recv = (direction < 0) ? socket_left : socket_right;
@ -846,7 +887,7 @@ class RingGroup : public GroupImpl {
sends[b].wait();
recvs[b].wait();
if (2 * j < send_plan.size()) {
sum_inplace<T>(
reduce_op(
recv_buffers[j % ALL_SUM_BUFFERS],
data + recv_plan[j].first,
recv_plan[j].second - recv_plan[j].first);

View File

@ -117,7 +117,64 @@ void init_distributed(nb::module_& parent_module) {
Returns:
array: The sum of all ``x`` arrays.
)pbdoc");
m.def(
"all_max",
[](const ScalarOrArray& x,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::all_max(to_array(x), group, s);
},
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def all_max(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
All reduce max.
Find the maximum of the ``x`` arrays from all processes in the group.
Args:
x (array): Input array.
group (Group): The group of processes that will participate in the
reduction. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The maximum of all ``x`` arrays.
)pbdoc");
m.def(
"all_min",
[](const ScalarOrArray& x,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::all_min(to_array(x), group, s);
},
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def all_min(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
All reduce min.
Find the minimum of the ``x`` arrays from all processes in the group.
Args:
x (array): Input array.
group (Group): The group of processes that will participate in the
reduction. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The minimum of all ``x`` arrays.
)pbdoc");
m.def(
"all_gather",
[](const ScalarOrArray& x,

View File

@ -124,6 +124,22 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
x = mx.distributed.recv_like(x, neighbor, group=pairs)
mx.eval(y, x)
def test_min_max(self):
world = mx.distributed.init()
base = mx.arange(16).reshape(4, 4)
x = base + world.rank() * 32
def _test_reduction(reduction: str = "all_max"):
target = base + ((world.size() - 1) * 16) * (reduction == "max")
reducer = getattr(mx.distributed, reduction)
y = reducer(x)
self.assertTrue(mx.allclose(y, target))
for reduction in ["all_max", "all_min"]:
_test_reduction(reduction)
if __name__ == "__main__":
unittest.main()

View File

@ -44,17 +44,23 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
(1024, 1024),
]
key = mx.random.key(0)
reductions = ["min", "max", "sum"]
for dt, rtol in dtypes:
for sh in sizes:
x = (
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
).astype(dt)
y = mx.distributed.all_sum(x[world.rank()])
z = sum(
x[i] for i in range(world.size())
) # to ensure that we don't sum to int32
maxrelerror = ((y - z).abs() / z.abs()).max()
self.assertLessEqual(maxrelerror, rtol)
for reduction in reductions:
reducer_distributed = getattr(mx.distributed, f"all_{reduction}")
y = reducer_distributed(x[world.rank()])
reducer = getattr(mx, reduction)
z = reducer(x, axis=0)
mx.eval(y, z)
maxrelerror = ((y - z).abs() / z.abs()).max()
self.assertLessEqual(maxrelerror, rtol)
def test_all_gather(self):
world = mx.distributed.init()