Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-27 20:07:59 +08:00)

Min / max reductions (#2041)

This commit is contained in:
parent 9ecefd56db
commit 515f104926
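The diff below threads min and max all-reduces through the CPU primitive, the MPI and ring backends, and the Python bindings. As orientation, here is a minimal usage sketch (not part of the commit); it assumes the script is launched on several processes so that the global group has more than one rank.

# Usage sketch for the reductions this commit adds (all_max / all_min),
# alongside the existing all_sum. Assumes a multi-process launch so that
# mx.distributed.init() returns a group with size > 1.
import mlx.core as mx

world = mx.distributed.init()

# Each rank contributes a shifted copy of the same array.
x = mx.arange(4) + world.rank()

total = mx.distributed.all_sum(x)  # existing: elementwise sum across ranks
hi = mx.distributed.all_max(x)     # new: elementwise max across ranks
lo = mx.distributed.all_min(x)     # new: elementwise min across ranks
mx.eval(total, hi, lo)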
@@ -46,8 +46,14 @@ void AllReduce::eval_cpu(
     case Sum:
       distributed::detail::all_sum(group(), in, outputs[0], stream());
       break;
+    case Max:
+      distributed::detail::all_max(group(), in, outputs[0], stream());
+      break;
+    case Min:
+      distributed::detail::all_min(group(), in, outputs[0], stream());
+      break;
     default:
-      throw std::runtime_error("Only all reduce sum is supported for now");
+      throw std::runtime_error("Only all reduce sum, min and max are supported for now");
   }
 }
@@ -15,6 +15,14 @@ void all_sum(Group group, const array& input, array& output, Stream stream) {
   group.raw_group()->all_sum(input, output, stream);
 }
 
+void all_max(Group group, const array& input, array& output, Stream stream) {
+  group.raw_group()->all_max(input, output, stream);
+}
+
+void all_min(Group group, const array& input, array& output, Stream stream) {
+  group.raw_group()->all_min(input, output, stream);
+}
+
 void all_gather(Group group, const array& input, array& output, Stream stream) {
   group.raw_group()->all_gather(input, output, stream);
 }
@@ -57,6 +65,16 @@ class EmptyGroup : public GroupImpl {
     throw std::runtime_error(
        "Communication not implemented in an empty distributed group.");
   }
 
+  void all_max(const array&, array&, Stream) override {
+    throw std::runtime_error(
+        "Communication not implemented in an empty distributed group.");
+  }
+
+  void all_min(const array&, array&, Stream) override {
+    throw std::runtime_error(
+        "Communication not implemented in an empty distributed group.");
+  }
+
 };
 
 } // namespace detail
@@ -21,6 +21,8 @@ class GroupImpl {
   virtual void all_gather(const array& input, array& output, Stream stream) = 0;
   virtual void send(const array& input, int dst, Stream stream) = 0;
   virtual void recv(array& out, int src, Stream stream) = 0;
+  virtual void all_max(const array& input, array& output, Stream stream) = 0;
+  virtual void all_min(const array& input, array& output, Stream stream) = 0;
 };
 
 /* Perform an all reduce sum operation */
@@ -35,4 +37,10 @@ void send(Group group, const array& input, int dst, Stream stream);
 /** Recv an array from the src rank */
 void recv(Group group, array& out, int src, Stream stream);
 
+/** Max reduction */
+void all_max(Group group, const array& input, array& output, Stream stream);
+
+/** Min reduction */
+void all_min(Group group, const array& input, array& output, Stream stream);
+
 } // namespace mlx::core::distributed::detail
@@ -93,6 +93,8 @@ struct MPIWrapper {
 
     // Ops
     LOAD_SYMBOL(ompi_mpi_op_sum, op_sum_);
+    LOAD_SYMBOL(ompi_mpi_op_max, op_max_);
+    LOAD_SYMBOL(ompi_mpi_op_min, op_min_);
 
     // Datatypes
     LOAD_SYMBOL(ompi_mpi_c_bool, mpi_bool_);
@@ -191,6 +193,14 @@ struct MPIWrapper {
     }
   }
 
+  MPI_Op op_max(const array& arr) {
+    return op_max_;
+  }
+
+  MPI_Op op_min(const array& arr) {
+    return op_min_;
+  }
+
   void* libmpi_handle_;
 
   // API
@@ -219,6 +229,8 @@ struct MPIWrapper {
   MPI_Op op_sum_;
   MPI_Op op_sum_f16_;
   MPI_Op op_sum_bf16_;
+  MPI_Op op_max_;
+  MPI_Op op_min_;
 
   // Datatypes
   MPI_Datatype mpi_bool_;
@@ -306,6 +318,36 @@ class MPIGroup : public GroupImpl {
         comm_);
   }
 
+  void all_max(const array& input, array& output, Stream stream) override {
+    auto& encoder = cpu::get_command_encoder(stream);
+    encoder.set_input_array(input);
+    encoder.set_output_array(output);
+    encoder.dispatch(
+        mpi().all_reduce,
+        (input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
+                                                    : input.data<void>(),
+        output.data<void>(),
+        input.size(),
+        mpi().datatype(input),
+        mpi().op_max(input),
+        comm_);
+  }
+
+  void all_min(const array& input, array& output, Stream stream) override {
+    auto& encoder = cpu::get_command_encoder(stream);
+    encoder.set_input_array(input);
+    encoder.set_output_array(output);
+    encoder.dispatch(
+        mpi().all_reduce,
+        (input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
+                                                    : input.data<void>(),
+        output.data<void>(),
+        input.size(),
+        mpi().datatype(input),
+        mpi().op_min(input),
+        comm_);
+  }
+
   void all_gather(const array& input, array& output, Stream stream) override {
     auto& encoder = cpu::get_command_encoder(stream);
     encoder.set_input_array(input);
@@ -36,6 +36,40 @@ array all_sum(
       {x});
 }
 
+array all_max(
+    const array& x,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  auto group = to_group(group_);
+
+  if (group.size() == 1) {
+    return x;
+  }
+  return array(
+      x.shape(),
+      x.dtype(),
+      std::make_shared<AllReduce>(
+          to_stream(s, Device::cpu), group, AllReduce::Max),
+      {x});
+}
+
+array all_min(
+    const array& x,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  auto group = to_group(group_);
+
+  if (group.size() == 1) {
+    return x;
+  }
+  return array(
+      x.shape(),
+      x.dtype(),
+      std::make_shared<AllReduce>(
+          to_stream(s, Device::cpu), group, AllReduce::Min),
+      {x});
+}
+
 array all_gather(
     const array& x,
     std::optional<Group> group_ /* = std::nullopt */,
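Note the group.size() == 1 guard in the new ops above: with a single process the reduction is a no-op and the input is returned unchanged. A quick single-process sketch (illustrative, not part of the commit):

# When only one process participates, all_max / all_min return the input
# array as-is (the group.size() == 1 early return above).
import mlx.core as mx

x = mx.array([3.0, -1.0, 7.0])
print(mx.distributed.all_max(x))  # same values as x in a standalone run
print(mx.distributed.all_min(x))  # same values as x in a standalone run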
@@ -38,4 +38,14 @@ array recv_like(
     std::optional<Group> group = std::nullopt,
     StreamOrDevice s = {});
 
+array all_max(
+    const array& x,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
+
+array all_min(
+    const array& x,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
+
 } // namespace mlx::core::distributed
@@ -15,8 +15,14 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
   switch (reduce_type_) {
     case Sum:
       return {{all_sum(inputs[0], group(), stream())}, axes};
+    case Max:
+      return {{all_max(inputs[0], group(), stream())}, axes};
+    case Min:
+      return {{all_min(inputs[0], group(), stream())}, axes};
     default:
-      throw std::runtime_error("Only all reduce sum is supported for now");
+      throw std::runtime_error(
+          "Only all reduce sum, max and min are supported for now");
   }
 }
@@ -27,8 +33,13 @@ std::vector<array> AllReduce::jvp(
   switch (reduce_type_) {
     case Sum:
       return {all_sum(tangents[0], group(), stream())};
+    case Max:
+      return {all_max(tangents[0], group(), stream())};
+    case Min:
+      return {all_min(tangents[0], group(), stream())};
     default:
-      throw std::runtime_error("Only all reduce sum is supported for now");
+      throw std::runtime_error(
+          "Only all reduce sum, max and min are supported for now");
   }
 }
@@ -503,15 +503,38 @@ std::vector<int> make_connections(
 
   return sockets;
 }
 
 template <typename T>
-void sum_inplace(const T* input, T* output, size_t N) {
+struct SumOp {
+  void operator()(const T* input, T* output, size_t N) {
     while (N-- > 0) {
       *output += *input;
       input++;
       output++;
     }
   }
+};
+
+template <typename T>
+struct MaxOp {
+  void operator()(const T* input, T* output, size_t N) {
+    while (N-- > 0) {
+      *output = std::max(*output, *input);
+      input++;
+      output++;
+    }
+  }
+};
+
+template <typename T>
+struct MinOp {
+  void operator()(const T* input, T* output, size_t N) {
+    while (N-- > 0) {
+      *output = std::min(*output, *input);
+      input++;
+      output++;
+    }
+  }
+};
+
 } // namespace
 
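The SumOp, MaxOp, and MinOp functors above let the ring backend run one generic all-reduce path with a pluggable elementwise combiner. The sketch below is only an illustration of that idea, not the repository's implementation: it simulates a ring all-reduce (reduce-scatter followed by all-gather) in plain Python with an arbitrary reduce_op. The function and variable names are hypothetical.

# Illustrative ring all-reduce with a pluggable elementwise reducer; only
# the algorithmic shape mirrors the RingGroup code.
import operator


def ring_all_reduce(per_rank, reduce_op):
    """per_rank[r] holds rank r's data split into one chunk per rank.

    Returns the per-rank data after the all-reduce; every rank ends up with
    the same, fully reduced chunks.
    """
    n = len(per_rank)
    data = [[list(chunk) for chunk in chunks] for chunks in per_rank]

    # Reduce-scatter: after n - 1 steps, rank r owns the fully reduced
    # chunk (r + 1) % n.
    for step in range(n - 1):
        sends = [(r, (r - step) % n, list(data[r][(r - step) % n])) for r in range(n)]
        for src, chunk, payload in sends:
            dst = (src + 1) % n
            data[dst][chunk] = [reduce_op(a, b) for a, b in zip(data[dst][chunk], payload)]

    # All-gather: circulate the reduced chunks so every rank has all of them.
    for step in range(n - 1):
        sends = [(r, (r + 1 - step) % n, list(data[r][(r + 1 - step) % n])) for r in range(n)]
        for src, chunk, payload in sends:
            data[(src + 1) % n][chunk] = payload

    return data


# Two ranks, two chunks each: elementwise max and sum across ranks.
per_rank = [
    [[1, 5], [2, 2]],  # rank 0
    [[4, 0], [7, 1]],  # rank 1
]
print(ring_all_reduce(per_rank, max)[0])           # [[4, 5], [7, 2]]
print(ring_all_reduce(per_rank, operator.add)[0])  # [[5, 5], [9, 3]]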
@@ -605,7 +628,18 @@ class RingGroup : public GroupImpl {
   }
 
   void all_sum(const array& input, array& output, Stream stream) override {
-    SWITCH_TYPE(output, all_sum<T>(input, output, stream));
+    SWITCH_TYPE(
+        output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()));
+  }
+
+  void all_max(const array& input, array& output, Stream stream) override {
+    SWITCH_TYPE(
+        output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()));
+  }
+
+  void all_min(const array& input, array& output, Stream stream) override {
+    SWITCH_TYPE(
+        output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
   }
 
   std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
@@ -694,13 +728,17 @@ class RingGroup : public GroupImpl {
   }
 
  private:
-  template <typename T>
-  void all_sum(const array& input, array& output, Stream stream) {
+  template <typename T, typename ReduceOp>
+  void all_reduce(
+      const array& input,
+      array& output,
+      Stream stream,
+      ReduceOp reduce_op) {
     auto in_ptr = input.data<char>();
     auto out_ptr = output.data<char>();
     auto& encoder = cpu::get_command_encoder(stream);
     encoder.set_output_array(output);
-    encoder.dispatch([in_ptr, out_ptr, size = input.size(), this]() {
+    encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() {
       // If the input data cannot be split into size_ segments then copy it and
       // all reduce a local buffer prefilled with 0s.
       size_t nbytes = size * sizeof(T);
|
|||||||
char buffer[1024];
|
char buffer[1024];
|
||||||
std::memset(buffer, 0, size_ * sizeof(T));
|
std::memset(buffer, 0, size_ * sizeof(T));
|
||||||
std::memcpy(buffer, in_ptr, nbytes);
|
std::memcpy(buffer, in_ptr, nbytes);
|
||||||
all_sum_impl<T>(
|
all_reduce_impl<T, ReduceOp>(
|
||||||
reinterpret_cast<T*>(buffers_.data()),
|
reinterpret_cast<T*>(buffers_.data()),
|
||||||
reinterpret_cast<T*>(buffer),
|
reinterpret_cast<T*>(buffer),
|
||||||
size_,
|
size_,
|
||||||
sockets_right_[0],
|
sockets_right_[0],
|
||||||
sockets_left_[0],
|
sockets_left_[0],
|
||||||
-1);
|
-1,
|
||||||
|
reduce_op);
|
||||||
std::memcpy(out_ptr, buffer, nbytes);
|
std::memcpy(out_ptr, buffer, nbytes);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -746,7 +785,7 @@ class RingGroup : public GroupImpl {
 
      for (int i = 0; i < n_reduces; i++) {
        all_sums.emplace_back(pool_.enqueue(std::bind(
-            &RingGroup::all_sum_impl<T>,
+            &RingGroup::all_reduce_impl<T, ReduceOp>,
            this,
            reinterpret_cast<T*>(
                buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS),
@@ -754,7 +793,8 @@ class RingGroup : public GroupImpl {
            std::min(size, (i + 1) * step) - i * step,
            sockets_right_[i / 2],
            sockets_left_[i / 2],
-            (i % 2) ? -1 : 1)));
+            (i % 2) ? -1 : 1,
+            reduce_op)));
      }
      for (auto& f : all_sums) {
        f.wait();
@@ -762,14 +802,15 @@ class RingGroup : public GroupImpl {
     });
   }
 
-  template <typename T>
-  void all_sum_impl(
+  template <typename T, typename ReduceOp>
+  void all_reduce_impl(
      T* buffer,
      T* data,
      size_t data_size,
      int socket_right,
      int socket_left,
-      int direction) {
+      int direction,
+      ReduceOp reduce_op) {
     // Choose which socket we send to and recv from
     int socket_send = (direction < 0) ? socket_right : socket_left;
     int socket_recv = (direction < 0) ? socket_left : socket_right;
@@ -846,7 +887,7 @@ class RingGroup : public GroupImpl {
          sends[b].wait();
          recvs[b].wait();
          if (2 * j < send_plan.size()) {
-            sum_inplace<T>(
+            reduce_op(
                recv_buffers[j % ALL_SUM_BUFFERS],
                data + recv_plan[j].first,
                recv_plan[j].second - recv_plan[j].first);
@@ -117,7 +117,64 @@ void init_distributed(nb::module_& parent_module) {
 
         Returns:
           array: The sum of all ``x`` arrays.
       )pbdoc");
+  m.def(
+      "all_max",
+      [](const ScalarOrArray& x,
+         std::optional<mx::distributed::Group> group,
+         mx::StreamOrDevice s) {
+        return mx::distributed::all_max(to_array(x), group, s);
+      },
+      "x"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def all_max(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        All reduce max.
+
+        Find the maximum of the ``x`` arrays from all processes in the group.
+
+        Args:
+          x (array): Input array.
+          group (Group): The group of processes that will participate in the
+            reduction. If set to ``None`` the global group is used. Default:
+            ``None``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``
+            in which case the default stream of the default device is used.
+
+        Returns:
+          array: The maximum of all ``x`` arrays.
+      )pbdoc");
+  m.def(
+      "all_min",
+      [](const ScalarOrArray& x,
+         std::optional<mx::distributed::Group> group,
+         mx::StreamOrDevice s) {
+        return mx::distributed::all_min(to_array(x), group, s);
+      },
+      "x"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def all_min(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        All reduce min.
+
+        Find the minimum of the ``x`` arrays from all processes in the group.
+
+        Args:
+          x (array): Input array.
+          group (Group): The group of processes that will participate in the
+            reduction. If set to ``None`` the global group is used. Default:
+            ``None``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``
+            in which case the default stream of the default device is used.
+
+        Returns:
+          array: The minimum of all ``x`` arrays.
+      )pbdoc");
   m.def(
       "all_gather",
       [](const ScalarOrArray& x,
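Because the bindings take a ScalarOrArray and convert with to_array, plain Python scalars are accepted as well. A small sketch (not from the diff; assumes a multi-process launch):

# Elementwise max / min of a per-rank scalar: a quick way to find the
# highest and lowest rank participating in the global group.
import mlx.core as mx

world = mx.distributed.init()
highest = mx.distributed.all_max(world.rank())
lowest = mx.distributed.all_min(world.rank())
mx.eval(highest, lowest)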
@@ -124,6 +124,22 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
                 x = mx.distributed.recv_like(x, neighbor, group=pairs)
                 mx.eval(y, x)
 
+    def test_min_max(self):
+        world = mx.distributed.init()
+        base = mx.arange(16).reshape(4, 4)
+        x = base + world.rank() * 32
+
+        def _test_reduction(reduction: str = "all_max"):
+
+            target = base + ((world.size() - 1) * 16) * (reduction == "max")
+            reducer = getattr(mx.distributed, reduction)
+            y = reducer(x)
+
+            self.assertTrue(mx.allclose(y, target))
+
+        for reduction in ["all_max", "all_min"]:
+            _test_reduction(reduction)
+
 
 if __name__ == "__main__":
     unittest.main()
@@ -44,15 +44,21 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
             (1024, 1024),
         ]
         key = mx.random.key(0)
+        reductions = ["min", "max", "sum"]
+
         for dt, rtol in dtypes:
             for sh in sizes:
                 x = (
                     mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
                 ).astype(dt)
-                y = mx.distributed.all_sum(x[world.rank()])
-                z = sum(
-                    x[i] for i in range(world.size())
-                )  # to ensure that we don't sum to int32
+                for reduction in reductions:
+                    reducer_distributed = getattr(mx.distributed, f"all_{reduction}")
+                    y = reducer_distributed(x[world.rank()])
+
+                    reducer = getattr(mx, reduction)
+                    z = reducer(x, axis=0)
+                    mx.eval(y, z)
+
                     maxrelerror = ((y - z).abs() / z.abs()).max()
                     self.assertLessEqual(maxrelerror, rtol)