nccl default for backend=any (#2528)

* nccl default for backend=any

* check num gpus + ensure row contiguous for all reduce

* comment
This commit is contained in:
Awni Hannun
2025-08-22 12:24:27 -07:00
committed by GitHub
parent 5722c147de
commit 068a4612e9
5 changed files with 68 additions and 31 deletions

View File

@@ -2,6 +2,7 @@
#include <unordered_map>
#include "mlx/backend/cuda/cuda.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h"
@@ -114,7 +115,7 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
}
// Create the requested communication group
std::shared_ptr<detail::GroupImpl> group;
std::shared_ptr<detail::GroupImpl> group{nullptr};
std::string bk_ = bk;
if (bk == "mpi") {
group = mpi::init(strict);
@@ -123,8 +124,14 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
} else if (bk == "nccl") {
group = nccl::init(strict);
} else if (bk == "any") {
group = ring::init(false);
bk_ = "ring";
if (mlx::core::cu::is_available()) {
group = nccl::init(false);
bk_ = "nccl";
}
if (group == nullptr) {
group = ring::init(false);
bk_ = "ring";
}
if (group == nullptr) {
group = mpi::init(false);
bk_ = "mpi";