mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 00:36:32 +08:00
nccl default for backend=any
This commit is contained in:
parent
5722c147de
commit
1eb589cd77
@ -405,6 +405,7 @@ jobs:
|
||||
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
sudo apt-get update
|
||||
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
|
||||
sudo apt-get install libnccl2 libnccl-dev
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install zip
|
||||
pip install auditwheel
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/mpi/mpi.h"
|
||||
@ -114,7 +115,7 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
||||
}
|
||||
|
||||
// Create the requested communication group
|
||||
std::shared_ptr<detail::GroupImpl> group;
|
||||
std::shared_ptr<detail::GroupImpl> group{nullptr};
|
||||
std::string bk_ = bk;
|
||||
if (bk == "mpi") {
|
||||
group = mpi::init(strict);
|
||||
@ -123,8 +124,14 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
||||
} else if (bk == "nccl") {
|
||||
group = nccl::init(strict);
|
||||
} else if (bk == "any") {
|
||||
group = ring::init(false);
|
||||
bk_ = "ring";
|
||||
if (mlx::core::cu::is_available()) {
|
||||
group = nccl::init(false);
|
||||
bk_ = "nccl";
|
||||
}
|
||||
if (group == nullptr) {
|
||||
group = ring::init(false);
|
||||
bk_ = "ring";
|
||||
}
|
||||
if (group == nullptr) {
|
||||
group = mpi::init(false);
|
||||
bk_ = "mpi";
|
||||
|
@ -204,13 +204,17 @@ inline void bootstrap_unique_id(
|
||||
int attempt = 0;
|
||||
bool connected = false;
|
||||
|
||||
bool do_log = std::getenv("NCCL_DEBUG") == "INFO";
|
||||
for (attempt = 0; attempt < max_retries; ++attempt) {
|
||||
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
||||
0) {
|
||||
connected = true;
|
||||
std::cout << "[Rank " << rank << "] Connected successfully on attempt "
|
||||
<< attempt + 1 << std::endl;
|
||||
break;
|
||||
if (do_log) {
|
||||
std::cout << "[Rank " << rank
|
||||
<< "] Connected successfully on attempt " << attempt + 1
|
||||
<< std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (errno != ECONNREFUSED) {
|
||||
break;
|
||||
@ -331,24 +335,33 @@ bool is_available() {
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
static std::string get_env_var_or_throw(const char* env_var_name) {
|
||||
std::string get_env_var_or_throw(const char* env_var_name, bool strict) {
|
||||
const char* value = std::getenv(env_var_name);
|
||||
if (value == nullptr) {
|
||||
if (value == nullptr && strict) {
|
||||
std::ostringstream msg;
|
||||
msg << "[nccl] Required environment variable '" << env_var_name
|
||||
<< "' is not set. "
|
||||
<< "Please set it before initializing the distributed backend.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
if (value == nullptr) {
|
||||
return "";
|
||||
}
|
||||
return std::string(value);
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||
std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP");
|
||||
std::string port = detail::get_env_var_or_throw("NCCL_PORT");
|
||||
std::string rank_str = detail::get_env_var_or_throw("MLX_RANK");
|
||||
std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE");
|
||||
std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP", strict);
|
||||
std::string port = detail::get_env_var_or_throw("NCCL_PORT", strict);
|
||||
std::string rank_str = detail::get_env_var_or_throw("MLX_RANK", strict);
|
||||
std::string n_nodes_str =
|
||||
detail::get_env_var_or_throw("MLX_WORLD_SIZE", strict);
|
||||
if (!strict &&
|
||||
(host.empty() || port.empty() || rank_str.empty() ||
|
||||
n_nodes_str.empty())) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int rank = std::stoi(rank_str);
|
||||
int n_nodes = std::stoi(n_nodes_str);
|
||||
|
@ -428,7 +428,7 @@ def launch_nccl(parser, hosts, args, command):
|
||||
base_env = os.environ.copy()
|
||||
base_env.update(
|
||||
{
|
||||
"NCCL_DEBUG": "INFO",
|
||||
"NCCL_DEBUG": base_env.get("NCCL_DEBUG", "DEBUG"),
|
||||
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
|
||||
"NCCL_HOST_IP": master_host,
|
||||
"NCCL_PORT": str(master_port),
|
||||
@ -821,8 +821,6 @@ def main():
|
||||
)
|
||||
|
||||
args, rest = parser.parse_known_args()
|
||||
if rest[0] == "--":
|
||||
rest.pop(0)
|
||||
|
||||
if args.print_python:
|
||||
print(sys.executable)
|
||||
|
Loading…
Reference in New Issue
Block a user