Min / max reductions (#2041)

This commit is contained in:
Anastasiia Filippova
2025-04-10 08:22:20 +02:00
committed by GitHub
parent 9ecefd56db
commit 515f104926
11 changed files with 276 additions and 27 deletions

View File

@@ -117,7 +117,64 @@ void init_distributed(nb::module_& parent_module) {
Returns:
array: The sum of all ``x`` arrays.
)pbdoc");
m.def(
"all_max",
[](const ScalarOrArray& x,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::all_max(to_array(x), group, s);
},
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def all_max(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
All reduce max.
Find the maximum of the ``x`` arrays from all processes in the group.
Args:
x (array): Input array.
group (Group): The group of processes that will participate in the
reduction. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The maximum of all ``x`` arrays.
)pbdoc");
m.def(
"all_min",
[](const ScalarOrArray& x,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::all_min(to_array(x), group, s);
},
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def all_min(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
All reduce min.
Find the minimum of the ``x`` arrays from all processes in the group.
Args:
x (array): Input array.
group (Group): The group of processes that will participate in the
reduction. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The minimum of all ``x`` arrays.
)pbdoc");
m.def(
"all_gather",
[](const ScalarOrArray& x,