NCCL backend (#2476)

This commit is contained in:
Anastasiia Filippova
2025-08-21 20:56:15 +02:00
committed by GitHub
parent e843c4d8d5
commit 9392fc3f88
21 changed files with 897 additions and 20 deletions

View File

@@ -5,12 +5,17 @@
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/nccl/nccl.h"
#include "mlx/distributed/ring/ring.h"
namespace mlx::core::distributed {
namespace detail {
// Asks the group's underlying implementation which stream should be used for
// communication, given the caller's stream-or-device preference `s`.
Stream communication_stream(Group group, StreamOrDevice s /* = {} */) {
  const auto& impl = group.raw_group();
  return impl->communication_stream(s);
}
// Forwards an all-sum request for `input` / `output` on `stream` to the
// group's underlying implementation.
void all_sum(Group group, const array& input, array& output, Stream stream) {
  const auto& impl = group.raw_group();
  impl->all_sum(input, output, stream);
}
@@ -37,6 +42,10 @@ void recv(Group group, array& out, int src, Stream stream) {
class EmptyGroup : public GroupImpl {
public:
// Resolves the stream-or-device argument with to_stream and returns the
// resulting stream unchanged.
Stream communication_stream(StreamOrDevice s) override {
  Stream resolved = to_stream(s);
  return resolved;
}
// This group always reports rank 0.
int rank() override {
  constexpr int kRank = 0;
  return kRank;
}
@@ -80,7 +89,7 @@ class EmptyGroup : public GroupImpl {
} // namespace detail
// Returns true when at least one of the MPI, ring, or NCCL backends reports
// itself as available.
bool is_available() {
  return mpi::is_available() || ring::is_available() || nccl::is_available();
}
int Group::rank() const {
@@ -111,6 +120,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = mpi::init(strict);
} else if (bk == "ring") {
group = ring::init(strict);
} else if (bk == "nccl") {
group = nccl::init(strict);
} else if (bk == "any") {
group = ring::init(false);
bk_ = "ring";