mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-16 22:28:12 +08:00
Nccl timeout (#2673)
* print the error & delete nccl group * timeout for nccl binding * typo * revert error * fixed a typo
This commit is contained in:

committed by
GitHub

parent
36ca62dba8
commit
e9eab527eb
@@ -21,6 +21,9 @@
|
||||
|
||||
namespace mlx::core::distributed::nccl {
|
||||
|
||||
// Default timeout in milliseconds; can be tuned with MLX_NCCL_TIMEOUT
|
||||
constexpr int nccl_timeout = 300000; // milliseconds
|
||||
|
||||
#define CHECK_CUDA(cmd) \
|
||||
do { \
|
||||
cudaError_t e = cmd; \
|
||||
@@ -181,8 +184,9 @@ inline void bootstrap_unique_id(
|
||||
close(sock);
|
||||
|
||||
} else {
|
||||
// Here just wanted to make show that rank 0 has enough time to bind
|
||||
// so we will retry to connect until max attempts
|
||||
// Here we want to make sure that rank 0 has enough time to bind
|
||||
// so we will retry to connect until elapsed time exceeds nccl_timeout
|
||||
// this is particularly important for multi-node setups
|
||||
|
||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
@@ -200,32 +204,41 @@ inline void bootstrap_unique_id(
|
||||
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
|
||||
serv.sin_port = htons(port);
|
||||
|
||||
const int max_retries = 30;
|
||||
int attempt = 0;
|
||||
const int timeout_ms = env::nccl_timeout(nccl_timeout);
|
||||
bool connected = false;
|
||||
|
||||
bool do_log = std::getenv("NCCL_DEBUG") == "INFO";
|
||||
for (attempt = 0; attempt < max_retries; ++attempt) {
|
||||
const char* dbg = std::getenv("NCCL_DEBUG");
|
||||
bool do_log = (dbg && std::string(dbg) == "INFO");
|
||||
|
||||
auto start = std::chrono::steady_clock::now();
|
||||
int attempt = 0;
|
||||
|
||||
while (true) {
|
||||
auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
std::chrono::steady_clock::now() - start)
|
||||
.count();
|
||||
if (elapsed_ms > timeout_ms)
|
||||
break;
|
||||
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
||||
0) {
|
||||
connected = true;
|
||||
if (do_log) {
|
||||
std::cout << "[Rank " << rank
|
||||
<< "] Connected successfully on attempt " << attempt + 1
|
||||
<< std::endl;
|
||||
std::cout << "[Rank " << rank << "] Connected successfully after "
|
||||
<< elapsed_ms << " miliseconds" << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (errno != ECONNREFUSED) {
|
||||
break;
|
||||
}
|
||||
++attempt;
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||
}
|
||||
|
||||
if (!connected) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Rank " << rank << "] connect() failed after " << attempt
|
||||
<< " retries: " << strerror(errno);
|
||||
msg << "[Rank " << rank << "] connect() failed after " << timeout_ms
|
||||
<< " milliseconds and " << attempt << " retries: " << strerror(errno);
|
||||
close(sock);
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
@@ -256,7 +269,6 @@ class NCCLGroup : public GroupImpl {
|
||||
|
||||
~NCCLGroup() {
|
||||
ncclCommDestroy(comm_);
|
||||
ncclGroupEnd();
|
||||
initialized_ = false;
|
||||
}
|
||||
|
||||
|
@@ -165,6 +165,11 @@ inline bool enable_tf32() {
|
||||
return enable_tf32_;
|
||||
}
|
||||
|
||||
inline int nccl_timeout(int default_value) {
|
||||
static int nccl_timeout = get_var("MLX_NCCL_TIMEOUT", default_value);
|
||||
return nccl_timeout;
|
||||
}
|
||||
|
||||
} // namespace env
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Reference in New Issue
Block a user