Nccl reduce scatter, all gather (#2727)

* Added reduce scatter and all gather for nccl

* fix unused import, delete unused file

* small fix

* deleted useless condition

* fixed comments

* fix bug in eval_gpu, renamed to sum_scatter, fix docs

* final fix docs

* remove and

* Update mlx/distributed/mpi/mpi.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix broken set input output

* fixes set output

* typo

* fix typo

* no cpu, no gpu for reduce scatter

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Anastasiia Filippova
2025-11-05 17:21:11 +01:00
committed by GitHub
parent 761f901a41
commit 27778156dc
19 changed files with 351 additions and 311 deletions

View File

@@ -229,7 +229,7 @@ void init_distributed(nb::module_& parent_module) {
x (array): Input array.
dst (int): Rank of the destination process in the group.
group (Group): The group of processes that will participate in the
sned. If set to ``None`` the global group is used. Default:
send. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
@@ -301,4 +301,36 @@ void init_distributed(nb::module_& parent_module) {
Returns:
array: The array that was received from ``src``.
)pbdoc");
m.def(
"sum_scatter",
[](const ScalarOrArray& x,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::sum_scatter(to_array(x), group, s);
},
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def sum_scatter(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Sum ``x`` across all processes in the group and shard the result along the first axis across ranks.
``x.shape[0]`` must be divisible by the group size.
The result is equivalent to ``all_sum(x)[rank*chunk_size:(rank+1)*chunk_size]``, where ``chunk_size = x.shape[0] // group.size()`` and ``rank`` is the rank of this process in the group.
Note: ``all_sum`` is mentioned only for illustration; the actual implementation does not perform ``all_sum`` and uses a single reduce-scatter collective instead.
Currently supported only for the NCCL backend.
Args:
x (array): Input array.
group (Group): The group of processes that will participate in the
sum scatter. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The output array with shape ``[x.shape[0] // group.size(), *x.shape[1:]]``.
)pbdoc");
}