From e9eab527eb51076b1a30b8ebdd4a2c6bdb284701 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 14 Oct 2025 21:29:54 +0200 Subject: [PATCH] Nccl timeout (#2673) * print the error & delete nccl group * timeout for nccl binding * typo * revert error * fixed a typo --- mlx/distributed/nccl/nccl.cpp | 36 +++++++++++++++++++++++------------ mlx/utils.h | 5 +++++ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp index 751ba9130..8a5376242 100644 --- a/mlx/distributed/nccl/nccl.cpp +++ b/mlx/distributed/nccl/nccl.cpp @@ -21,6 +21,9 @@ namespace mlx::core::distributed::nccl { +// Can be tuned with MLX_NCCL_TIMEOUT +constexpr int nccl_timeout = 300000; // milliseconds + #define CHECK_CUDA(cmd) \ do { \ cudaError_t e = cmd; \ @@ -181,8 +184,9 @@ inline void bootstrap_unique_id( close(sock); } else { - // Here just wanted to make show that rank 0 has enough time to bind - // so we will retry to connect until max attempts + // Here we want to make sure that rank 0 has enough time to bind + // so we will retry to connect until elapsed time exceeds nccl_timeout + // this is particularly important for multinode setup int sock = socket(AF_INET, SOCK_STREAM, 0); if (sock < 0) { @@ -200,32 +204,41 @@ inline void bootstrap_unique_id( memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length); serv.sin_port = htons(port); - const int max_retries = 30; - int attempt = 0; + const int timeout_ms = env::nccl_timeout(nccl_timeout); bool connected = false; - bool do_log = std::getenv("NCCL_DEBUG") == "INFO"; - for (attempt = 0; attempt < max_retries; ++attempt) { + const char* dbg = std::getenv("NCCL_DEBUG"); + bool do_log = (dbg && std::string(dbg) == "INFO"); + + auto start = std::chrono::steady_clock::now(); + int attempt = 0; + + while (true) { + auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now() - start) + .count(); + if (elapsed_ms > timeout_ms) + break; if
(connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) == 0) { connected = true; if (do_log) { - std::cout << "[Rank " << rank - << "] Connected successfully on attempt " << attempt + 1 - << std::endl; + std::cout << "[Rank " << rank << "] Connected successfully after " + << elapsed_ms << " miliseconds" << std::endl; break; } } if (errno != ECONNREFUSED) { break; } + ++attempt; std::this_thread::sleep_for(std::chrono::milliseconds(500)); } if (!connected) { std::ostringstream msg; - msg << "[Rank " << rank << "] connect() failed after " << attempt - << " retries: " << strerror(errno); + msg << "[Rank " << rank << "] connect() failed after " << timeout_ms + << " milliseconds and " << attempt << " retries: " << strerror(errno); close(sock); throw std::runtime_error(msg.str()); } @@ -256,7 +269,6 @@ class NCCLGroup : public GroupImpl { ~NCCLGroup() { ncclCommDestroy(comm_); - ncclGroupEnd(); initialized_ = false; } diff --git a/mlx/utils.h b/mlx/utils.h index 076842f78..dbf79a71f 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -165,6 +165,11 @@ inline bool enable_tf32() { return enable_tf32_; } +inline int nccl_timeout(int default_value) { + static int nccl_timeout = get_var("MLX_NCCL_TIMEOUT", default_value); + return nccl_timeout; +} + } // namespace env } // namespace mlx::core