Nccl timeout (#2673)

* print the error & delete nccl group

* timeout for nccl binding

* typo

* revert error

* fixed a typo
This commit is contained in:
Anastasiia Filippova
2025-10-14 21:29:54 +02:00
committed by GitHub
parent 36ca62dba8
commit e9eab527eb
2 changed files with 29 additions and 12 deletions

View File

@@ -21,6 +21,9 @@
namespace mlx::core::distributed::nccl {
// Can be tuned with MLX_NCCL_TIMEOUT
constexpr int nccl_timeout = 300000; // miliseconds
#define CHECK_CUDA(cmd) \
do { \
cudaError_t e = cmd; \
@@ -181,8 +184,9 @@ inline void bootstrap_unique_id(
close(sock);
} else {
// Here just wanted to make show that rank 0 has enough time to bind
// so we will retry to connect until max attempts
// Here we want to make sure that rank 0 has enough time to bind
// so we will retry to connect until elapsed time exceeds nccl_timeout
// this is particularity important for multinode setup
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
@@ -200,32 +204,41 @@ inline void bootstrap_unique_id(
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
serv.sin_port = htons(port);
const int max_retries = 30;
int attempt = 0;
const int timeout_ms = env::nccl_timeout(nccl_timeout);
bool connected = false;
bool do_log = std::getenv("NCCL_DEBUG") == "INFO";
for (attempt = 0; attempt < max_retries; ++attempt) {
const char* dbg = std::getenv("NCCL_DEBUG");
bool do_log = (dbg && std::string(dbg) == "INFO");
auto start = std::chrono::steady_clock::now();
int attempt = 0;
while (true) {
auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - start)
.count();
if (elapsed_ms > timeout_ms)
break;
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
0) {
connected = true;
if (do_log) {
std::cout << "[Rank " << rank
<< "] Connected successfully on attempt " << attempt + 1
<< std::endl;
std::cout << "[Rank " << rank << "] Connected successfully after "
<< elapsed_ms << " miliseconds" << std::endl;
break;
}
}
if (errno != ECONNREFUSED) {
break;
}
++attempt;
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}
if (!connected) {
std::ostringstream msg;
msg << "[Rank " << rank << "] connect() failed after " << attempt
<< " retries: " << strerror(errno);
msg << "[Rank " << rank << "] connect() failed after " << timeout_ms
<< " milliseconds and " << attempt << " retries: " << strerror(errno);
close(sock);
throw std::runtime_error(msg.str());
}
@@ -256,7 +269,6 @@ class NCCLGroup : public GroupImpl {
~NCCLGroup() {
ncclCommDestroy(comm_);
ncclGroupEnd();
initialized_ = false;
}

View File

@@ -165,6 +165,11 @@ inline bool enable_tf32() {
return enable_tf32_;
}
inline int nccl_timeout(int default_value) {
static int nccl_timeout = get_var("MLX_NCCL_TIMEOUT", default_value);
return nccl_timeout;
}
} // namespace env
} // namespace mlx::core