NCCL reduce-scatter and all-gather (#2727)

* Added reduce-scatter and all-gather for NCCL

* fix unused import, delete unused file

* small fix

* deleted useless condition

* fixed comments

* fix bug in eval_gpu, renamed to sum_scatter, fix docs

* final fix docs

* remove and

* Update mlx/distributed/mpi/mpi.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix broken set input output

* fixes set output

* typo

* fix typo

* no cpu, no gpu for reduce scatter

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Anastasiia Filippova
2025-11-05 17:21:11 +01:00
committed by GitHub
parent 761f901a41
commit 27778156dc
19 changed files with 351 additions and 311 deletions

View File

@@ -41,6 +41,14 @@ void recv(Group group, array& out, int src, Stream stream) {
group.raw_group()->recv(out, src, stream);
}
// Dispatches a sum-scatter request to the implementation object held by
// `group`. `input`, `output` and `stream` are handed through unchanged; this
// wrapper performs no checks of its own.
// NOTE(review): presumably a reduce-scatter with a sum reduction, as the name
// suggests -- confirm against the GroupImpl documentation.
void sum_scatter(
    Group group,
    const array& input,
    array& output,
    Stream stream) {
  const auto& backend = group.raw_group();
  backend->sum_scatter(input, output, stream);
}
class EmptyGroup : public GroupImpl {
public:
Stream communication_stream(StreamOrDevice s) override {
@@ -85,6 +93,10 @@ class EmptyGroup : public GroupImpl {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
// EmptyGroup does not implement communication: this override ignores all of
// its arguments and always throws std::runtime_error.
void sum_scatter(const array&, array&, Stream) override {
  const char* message =
      "Communication not implemented in an empty distributed group.";
  throw std::runtime_error(message);
}
};
} // namespace detail