mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
nccl default for backend=any (#2528)
* nccl default for backend=any * check num gpus + ensure row contiguous for all reduce * comment
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/mpi/mpi.h"
|
||||
@@ -114,7 +115,7 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
||||
}
|
||||
|
||||
// Create the requested communication group
|
||||
std::shared_ptr<detail::GroupImpl> group;
|
||||
std::shared_ptr<detail::GroupImpl> group{nullptr};
|
||||
std::string bk_ = bk;
|
||||
if (bk == "mpi") {
|
||||
group = mpi::init(strict);
|
||||
@@ -123,8 +124,14 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
||||
} else if (bk == "nccl") {
|
||||
group = nccl::init(strict);
|
||||
} else if (bk == "any") {
|
||||
group = ring::init(false);
|
||||
bk_ = "ring";
|
||||
if (mlx::core::cu::is_available()) {
|
||||
group = nccl::init(false);
|
||||
bk_ = "nccl";
|
||||
}
|
||||
if (group == nullptr) {
|
||||
group = ring::init(false);
|
||||
bk_ = "ring";
|
||||
}
|
||||
if (group == nullptr) {
|
||||
group = mpi::init(false);
|
||||
bk_ = "mpi";
|
||||
|
||||
Reference in New Issue
Block a user