diff --git a/.circleci/config.yml b/.circleci/config.yml index 03987d39c..70f4e0fe5 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -405,6 +405,7 @@ jobs: sudo dpkg -i cuda-keyring_1.1-1_all.deb sudo apt-get update sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12 + sudo apt-get install libnccl2 libnccl-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install zip pip install auditwheel diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp index 44205e87e..d71ebb9b1 100644 --- a/mlx/distributed/distributed.cpp +++ b/mlx/distributed/distributed.cpp @@ -2,6 +2,7 @@ #include +#include "mlx/backend/cuda/cuda.h" #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/mpi/mpi.h" @@ -114,7 +115,7 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) { } // Create the requested communication group - std::shared_ptr group; + std::shared_ptr group{nullptr}; std::string bk_ = bk; if (bk == "mpi") { group = mpi::init(strict); @@ -123,8 +124,14 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) { } else if (bk == "nccl") { group = nccl::init(strict); } else if (bk == "any") { - group = ring::init(false); - bk_ = "ring"; + if (mlx::core::cu::is_available()) { + group = nccl::init(false); + bk_ = "nccl"; + } + if (group == nullptr) { + group = ring::init(false); + bk_ = "ring"; + } if (group == nullptr) { group = mpi::init(false); bk_ = "mpi"; diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp index 43af9c724..751ba9130 100644 --- a/mlx/distributed/nccl/nccl.cpp +++ b/mlx/distributed/nccl/nccl.cpp @@ -204,13 +204,17 @@ inline void bootstrap_unique_id( int attempt = 0; bool connected = false; + bool do_log = std::getenv("NCCL_DEBUG") == "INFO"; for (attempt = 0; attempt < max_retries; ++attempt) { if (connect(sock, reinterpret_cast(&serv), sizeof(serv)) == 0) { connected = true; - std::cout << "[Rank " << rank << "] Connected successfully on attempt " - << attempt + 1 << std::endl; - break; + if (do_log) { + std::cout << "[Rank " << rank + << "] Connected successfully on attempt " << attempt + 1 + << std::endl; + break; + } } if (errno != ECONNREFUSED) { break; @@ -331,24 +335,33 @@ bool is_available() { } namespace detail { -static std::string get_env_var_or_throw(const char* env_var_name) { +std::string get_env_var_or_throw(const char* env_var_name, bool strict) { const char* value = std::getenv(env_var_name); - if (value == nullptr) { + if (value == nullptr && strict) { std::ostringstream msg; msg << "[nccl] Required environment variable '" << env_var_name << "' is not set. " << "Please set it before initializing the distributed backend."; throw std::runtime_error(msg.str()); } + if (value == nullptr) { + return ""; + } return std::string(value); } } // namespace detail std::shared_ptr init(bool strict /* = false */) { - std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP"); - std::string port = detail::get_env_var_or_throw("NCCL_PORT"); - std::string rank_str = detail::get_env_var_or_throw("MLX_RANK"); - std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE"); + std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP", strict); + std::string port = detail::get_env_var_or_throw("NCCL_PORT", strict); + std::string rank_str = detail::get_env_var_or_throw("MLX_RANK", strict); + std::string n_nodes_str = + detail::get_env_var_or_throw("MLX_WORLD_SIZE", strict); + if (!strict && + (host.empty() || port.empty() || rank_str.empty() || + n_nodes_str.empty())) { + return nullptr; + } int rank = std::stoi(rank_str); int n_nodes = std::stoi(n_nodes_str); diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 3dd373eb1..a588326de 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -428,7 +428,7 @@ def launch_nccl(parser, hosts, args, command): base_env = os.environ.copy() base_env.update( { - "NCCL_DEBUG": "INFO", + "NCCL_DEBUG": base_env.get("NCCL_DEBUG", "DEBUG"), "NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication "NCCL_HOST_IP": master_host, "NCCL_PORT": str(master_port), @@ -821,8 +821,6 @@ def main(): ) args, rest = parser.parse_known_args() - if rest[0] == "--": - rest.pop(0) if args.print_python: print(sys.executable)