mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-25 03:31:17 +08:00
Min / max reductions (#2041)
This commit is contained in:
parent
9ecefd56db
commit
515f104926
@ -46,8 +46,14 @@ void AllReduce::eval_cpu(
|
||||
case Sum:
|
||||
distributed::detail::all_sum(group(), in, outputs[0], stream());
|
||||
break;
|
||||
case Max:
|
||||
distributed::detail::all_max(group(), in, outputs[0], stream());
|
||||
break;
|
||||
case Min:
|
||||
distributed::detail::all_min(group(), in, outputs[0], stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
throw std::runtime_error("Only all reduce sum, min and max are supported for now");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -15,6 +15,14 @@ void all_sum(Group group, const array& input, array& output, Stream stream) {
|
||||
group.raw_group()->all_sum(input, output, stream);
|
||||
}
|
||||
|
||||
void all_max(Group group, const array& input, array& output, Stream stream) {
|
||||
group.raw_group()->all_max(input, output, stream);
|
||||
}
|
||||
|
||||
void all_min(Group group, const array& input, array& output, Stream stream) {
|
||||
group.raw_group()->all_min(input, output, stream);
|
||||
}
|
||||
|
||||
void all_gather(Group group, const array& input, array& output, Stream stream) {
|
||||
group.raw_group()->all_gather(input, output, stream);
|
||||
}
|
||||
@ -57,6 +65,16 @@ class EmptyGroup : public GroupImpl {
|
||||
throw std::runtime_error(
|
||||
"Communication not implemented in an empty distributed group.");
|
||||
}
|
||||
|
||||
void all_max(const array&, array&, Stream) override {
|
||||
throw std::runtime_error(
|
||||
"Communication not implemented in an empty distributed group.");
|
||||
}
|
||||
|
||||
void all_min(const array&, array&, Stream) override {
|
||||
throw std::runtime_error(
|
||||
"Communication not implemented in an empty distributed group.");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
@ -21,6 +21,8 @@ class GroupImpl {
|
||||
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
|
||||
virtual void send(const array& input, int dst, Stream stream) = 0;
|
||||
virtual void recv(array& out, int src, Stream stream) = 0;
|
||||
virtual void all_max(const array& input, array& output, Stream stream) = 0;
|
||||
virtual void all_min(const array& input, array& output, Stream stream) = 0;
|
||||
};
|
||||
|
||||
/* Perform an all reduce sum operation */
|
||||
@ -35,4 +37,10 @@ void send(Group group, const array& input, int dst, Stream stream);
|
||||
/** Recv an array from the src rank */
|
||||
void recv(Group group, array& out, int src, Stream stream);
|
||||
|
||||
/** Max reduction */
|
||||
void all_max(Group group, const array& input, array& output, Stream stream);
|
||||
|
||||
/** Min reduction */
|
||||
void all_min(Group group, const array& input, array& output, Stream stream);
|
||||
|
||||
} // namespace mlx::core::distributed::detail
|
||||
|
@ -93,6 +93,8 @@ struct MPIWrapper {
|
||||
|
||||
// Ops
|
||||
LOAD_SYMBOL(ompi_mpi_op_sum, op_sum_);
|
||||
LOAD_SYMBOL(ompi_mpi_op_max, op_max_);
|
||||
LOAD_SYMBOL(ompi_mpi_op_min, op_min_);
|
||||
|
||||
// Datatypes
|
||||
LOAD_SYMBOL(ompi_mpi_c_bool, mpi_bool_);
|
||||
@ -191,6 +193,14 @@ struct MPIWrapper {
|
||||
}
|
||||
}
|
||||
|
||||
MPI_Op op_max(const array& arr) {
|
||||
return op_max_;
|
||||
}
|
||||
|
||||
MPI_Op op_min(const array& arr) {
|
||||
return op_min_;
|
||||
}
|
||||
|
||||
void* libmpi_handle_;
|
||||
|
||||
// API
|
||||
@ -219,6 +229,8 @@ struct MPIWrapper {
|
||||
MPI_Op op_sum_;
|
||||
MPI_Op op_sum_f16_;
|
||||
MPI_Op op_sum_bf16_;
|
||||
MPI_Op op_max_;
|
||||
MPI_Op op_min_;
|
||||
|
||||
// Datatypes
|
||||
MPI_Datatype mpi_bool_;
|
||||
@ -306,6 +318,36 @@ class MPIGroup : public GroupImpl {
|
||||
comm_);
|
||||
}
|
||||
|
||||
void all_max(const array& input, array& output, Stream stream) override {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(input);
|
||||
encoder.set_output_array(output);
|
||||
encoder.dispatch(
|
||||
mpi().all_reduce,
|
||||
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
|
||||
: input.data<void>(),
|
||||
output.data<void>(),
|
||||
input.size(),
|
||||
mpi().datatype(input),
|
||||
mpi().op_max(input),
|
||||
comm_);
|
||||
}
|
||||
|
||||
void all_min(const array& input, array& output, Stream stream) override {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(input);
|
||||
encoder.set_output_array(output);
|
||||
encoder.dispatch(
|
||||
mpi().all_reduce,
|
||||
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
|
||||
: input.data<void>(),
|
||||
output.data<void>(),
|
||||
input.size(),
|
||||
mpi().datatype(input),
|
||||
mpi().op_min(input),
|
||||
comm_);
|
||||
}
|
||||
|
||||
void all_gather(const array& input, array& output, Stream stream) override {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(input);
|
||||
|
@ -36,6 +36,40 @@ array all_sum(
|
||||
{x});
|
||||
}
|
||||
|
||||
array all_max(
|
||||
const array& x,
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
}
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<AllReduce>(
|
||||
to_stream(s, Device::cpu), group, AllReduce::Max),
|
||||
{x});
|
||||
}
|
||||
|
||||
array all_min(
|
||||
const array& x,
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
}
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<AllReduce>(
|
||||
to_stream(s, Device::cpu), group, AllReduce::Min),
|
||||
{x});
|
||||
}
|
||||
|
||||
array all_gather(
|
||||
const array& x,
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
|
@ -38,4 +38,14 @@ array recv_like(
|
||||
std::optional<Group> group = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
array all_max(
|
||||
const array& x,
|
||||
std::optional<Group> group = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
array all_min(
|
||||
const array& x,
|
||||
std::optional<Group> group = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
@ -15,8 +15,14 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
return {{all_sum(inputs[0], group(), stream())}, axes};
|
||||
case Max:
|
||||
return {{all_max(inputs[0], group(), stream())}, axes};
|
||||
case Min:
|
||||
return {{all_min(inputs[0], group(), stream())}, axes};
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
|
||||
throw std::runtime_error(
|
||||
"Only all reduce sum, max and min are supported for now");
|
||||
}
|
||||
}
|
||||
|
||||
@ -27,8 +33,13 @@ std::vector<array> AllReduce::jvp(
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
return {all_sum(tangents[0], group(), stream())};
|
||||
case Max:
|
||||
return {all_max(tangents[0], group(), stream())};
|
||||
case Min:
|
||||
return {all_min(tangents[0], group(), stream())};
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
throw std::runtime_error(
|
||||
"Only all reduce sum, max and min are supported for now");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -503,15 +503,38 @@ std::vector<int> make_connections(
|
||||
|
||||
return sockets;
|
||||
}
|
||||
template <typename T>
|
||||
struct SumOp {
|
||||
void operator()(const T* input, T* output, size_t N) {
|
||||
while (N-- > 0) {
|
||||
*output += *input;
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void sum_inplace(const T* input, T* output, size_t N) {
|
||||
while (N-- > 0) {
|
||||
*output += *input;
|
||||
input++;
|
||||
output++;
|
||||
struct MaxOp {
|
||||
void operator()(const T* input, T* output, size_t N) {
|
||||
while (N-- > 0) {
|
||||
*output = std::max(*output, *input);
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct MinOp {
|
||||
void operator()(const T* input, T* output, size_t N) {
|
||||
while (N-- > 0) {
|
||||
*output = std::min(*output, *input);
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -605,7 +628,18 @@ class RingGroup : public GroupImpl {
|
||||
}
|
||||
|
||||
void all_sum(const array& input, array& output, Stream stream) override {
|
||||
SWITCH_TYPE(output, all_sum<T>(input, output, stream));
|
||||
SWITCH_TYPE(
|
||||
output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()));
|
||||
}
|
||||
|
||||
void all_max(const array& input, array& output, Stream stream) override {
|
||||
SWITCH_TYPE(
|
||||
output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()));
|
||||
}
|
||||
|
||||
void all_min(const array& input, array& output, Stream stream) override {
|
||||
SWITCH_TYPE(
|
||||
output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
|
||||
}
|
||||
|
||||
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||
@ -694,13 +728,17 @@ class RingGroup : public GroupImpl {
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void all_sum(const array& input, array& output, Stream stream) {
|
||||
template <typename T, typename ReduceOp>
|
||||
void all_reduce(
|
||||
const array& input,
|
||||
array& output,
|
||||
Stream stream,
|
||||
ReduceOp reduce_op) {
|
||||
auto in_ptr = input.data<char>();
|
||||
auto out_ptr = output.data<char>();
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(output);
|
||||
encoder.dispatch([in_ptr, out_ptr, size = input.size(), this]() {
|
||||
encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() {
|
||||
// If the input data cannot be split into size_ segments then copy it and
|
||||
// all reduce a local buffer prefilled with 0s.
|
||||
size_t nbytes = size * sizeof(T);
|
||||
@ -717,13 +755,14 @@ class RingGroup : public GroupImpl {
|
||||
char buffer[1024];
|
||||
std::memset(buffer, 0, size_ * sizeof(T));
|
||||
std::memcpy(buffer, in_ptr, nbytes);
|
||||
all_sum_impl<T>(
|
||||
all_reduce_impl<T, ReduceOp>(
|
||||
reinterpret_cast<T*>(buffers_.data()),
|
||||
reinterpret_cast<T*>(buffer),
|
||||
size_,
|
||||
sockets_right_[0],
|
||||
sockets_left_[0],
|
||||
-1);
|
||||
-1,
|
||||
reduce_op);
|
||||
std::memcpy(out_ptr, buffer, nbytes);
|
||||
return;
|
||||
}
|
||||
@ -746,7 +785,7 @@ class RingGroup : public GroupImpl {
|
||||
|
||||
for (int i = 0; i < n_reduces; i++) {
|
||||
all_sums.emplace_back(pool_.enqueue(std::bind(
|
||||
&RingGroup::all_sum_impl<T>,
|
||||
&RingGroup::all_reduce_impl<T, ReduceOp>,
|
||||
this,
|
||||
reinterpret_cast<T*>(
|
||||
buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS),
|
||||
@ -754,7 +793,8 @@ class RingGroup : public GroupImpl {
|
||||
std::min(size, (i + 1) * step) - i * step,
|
||||
sockets_right_[i / 2],
|
||||
sockets_left_[i / 2],
|
||||
(i % 2) ? -1 : 1)));
|
||||
(i % 2) ? -1 : 1,
|
||||
reduce_op)));
|
||||
}
|
||||
for (auto& f : all_sums) {
|
||||
f.wait();
|
||||
@ -762,14 +802,15 @@ class RingGroup : public GroupImpl {
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void all_sum_impl(
|
||||
template <typename T, typename ReduceOp>
|
||||
void all_reduce_impl(
|
||||
T* buffer,
|
||||
T* data,
|
||||
size_t data_size,
|
||||
int socket_right,
|
||||
int socket_left,
|
||||
int direction) {
|
||||
int direction,
|
||||
ReduceOp reduce_op) {
|
||||
// Choose which socket we send to and recv from
|
||||
int socket_send = (direction < 0) ? socket_right : socket_left;
|
||||
int socket_recv = (direction < 0) ? socket_left : socket_right;
|
||||
@ -846,7 +887,7 @@ class RingGroup : public GroupImpl {
|
||||
sends[b].wait();
|
||||
recvs[b].wait();
|
||||
if (2 * j < send_plan.size()) {
|
||||
sum_inplace<T>(
|
||||
reduce_op(
|
||||
recv_buffers[j % ALL_SUM_BUFFERS],
|
||||
data + recv_plan[j].first,
|
||||
recv_plan[j].second - recv_plan[j].first);
|
||||
|
@ -117,7 +117,64 @@ void init_distributed(nb::module_& parent_module) {
|
||||
Returns:
|
||||
array: The sum of all ``x`` arrays.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"all_max",
|
||||
[](const ScalarOrArray& x,
|
||||
std::optional<mx::distributed::Group> group,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::distributed::all_max(to_array(x), group, s);
|
||||
},
|
||||
"x"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def all_max(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
All reduce max.
|
||||
|
||||
Find the maximum of the ``x`` arrays from all processes in the group.
|
||||
|
||||
Args:
|
||||
x (array): Input array.
|
||||
group (Group): The group of processes that will participate in the
|
||||
reduction. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
array: The maximum of all ``x`` arrays.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"all_min",
|
||||
[](const ScalarOrArray& x,
|
||||
std::optional<mx::distributed::Group> group,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::distributed::all_min(to_array(x), group, s);
|
||||
},
|
||||
"x"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def all_min(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
All reduce min.
|
||||
|
||||
Find the minimum of the ``x`` arrays from all processes in the group.
|
||||
|
||||
Args:
|
||||
x (array): Input array.
|
||||
group (Group): The group of processes that will participate in the
|
||||
reduction. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
array: The minimum of all ``x`` arrays.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"all_gather",
|
||||
[](const ScalarOrArray& x,
|
||||
|
@ -124,6 +124,22 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
x = mx.distributed.recv_like(x, neighbor, group=pairs)
|
||||
mx.eval(y, x)
|
||||
|
||||
def test_min_max(self):
|
||||
world = mx.distributed.init()
|
||||
base = mx.arange(16).reshape(4, 4)
|
||||
x = base + world.rank() * 32
|
||||
|
||||
def _test_reduction(reduction: str = "all_max"):
|
||||
|
||||
target = base + ((world.size() - 1) * 16) * (reduction == "max")
|
||||
reducer = getattr(mx.distributed, reduction)
|
||||
y = reducer(x)
|
||||
|
||||
self.assertTrue(mx.allclose(y, target))
|
||||
|
||||
for reduction in ["all_max", "all_min"]:
|
||||
_test_reduction(reduction)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -44,17 +44,23 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
(1024, 1024),
|
||||
]
|
||||
key = mx.random.key(0)
|
||||
reductions = ["min", "max", "sum"]
|
||||
|
||||
for dt, rtol in dtypes:
|
||||
for sh in sizes:
|
||||
x = (
|
||||
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
|
||||
).astype(dt)
|
||||
y = mx.distributed.all_sum(x[world.rank()])
|
||||
z = sum(
|
||||
x[i] for i in range(world.size())
|
||||
) # to ensure that we don't sum to int32
|
||||
maxrelerror = ((y - z).abs() / z.abs()).max()
|
||||
self.assertLessEqual(maxrelerror, rtol)
|
||||
for reduction in reductions:
|
||||
reducer_distributed = getattr(mx.distributed, f"all_{reduction}")
|
||||
y = reducer_distributed(x[world.rank()])
|
||||
|
||||
reducer = getattr(mx, reduction)
|
||||
z = reducer(x, axis=0)
|
||||
mx.eval(y, z)
|
||||
|
||||
maxrelerror = ((y - z).abs() / z.abs()).max()
|
||||
self.assertLessEqual(maxrelerror, rtol)
|
||||
|
||||
def test_all_gather(self):
|
||||
world = mx.distributed.init()
|
||||
|
Loading…
Reference in New Issue
Block a user