mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 06:38:38 +08:00
Nccl timeout (#2673)
* print the error & delete nccl group * timeout for nccl binding * typo * revert error * fixed a typo
This commit is contained in:

committed by
GitHub

parent
36ca62dba8
commit
e9eab527eb
@@ -21,6 +21,9 @@
|
|||||||
|
|
||||||
namespace mlx::core::distributed::nccl {
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
|
// Can be tuned with MLX_NCCL_TIMEOUT
|
||||||
|
constexpr int nccl_timeout = 300000; // milliseconds
|
||||||
|
|
||||||
#define CHECK_CUDA(cmd) \
|
#define CHECK_CUDA(cmd) \
|
||||||
do { \
|
do { \
|
||||||
cudaError_t e = cmd; \
|
cudaError_t e = cmd; \
|
||||||
@@ -181,8 +184,9 @@ inline void bootstrap_unique_id(
|
|||||||
close(sock);
|
close(sock);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// Here just wanted to make show that rank 0 has enough time to bind
|
// Here we want to make sure that rank 0 has enough time to bind
|
||||||
// so we will retry to connect until max attempts
|
// so we will retry to connect until elapsed time exceeds nccl_timeout
|
||||||
|
// this is particularly important for multi-node setups
|
||||||
|
|
||||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
if (sock < 0) {
|
if (sock < 0) {
|
||||||
@@ -200,32 +204,41 @@ inline void bootstrap_unique_id(
|
|||||||
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
|
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
|
||||||
serv.sin_port = htons(port);
|
serv.sin_port = htons(port);
|
||||||
|
|
||||||
const int max_retries = 30;
|
const int timeout_ms = env::nccl_timeout(nccl_timeout);
|
||||||
int attempt = 0;
|
|
||||||
bool connected = false;
|
bool connected = false;
|
||||||
|
|
||||||
bool do_log = std::getenv("NCCL_DEBUG") == "INFO";
|
const char* dbg = std::getenv("NCCL_DEBUG");
|
||||||
for (attempt = 0; attempt < max_retries; ++attempt) {
|
bool do_log = (dbg && std::string(dbg) == "INFO");
|
||||||
|
|
||||||
|
auto start = std::chrono::steady_clock::now();
|
||||||
|
int attempt = 0;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||||
|
std::chrono::steady_clock::now() - start)
|
||||||
|
.count();
|
||||||
|
if (elapsed_ms > timeout_ms)
|
||||||
|
break;
|
||||||
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
||||||
0) {
|
0) {
|
||||||
connected = true;
|
connected = true;
|
||||||
if (do_log) {
|
if (do_log) {
|
||||||
std::cout << "[Rank " << rank
|
std::cout << "[Rank " << rank << "] Connected successfully after "
|
||||||
<< "] Connected successfully on attempt " << attempt + 1
|
<< elapsed_ms << " miliseconds" << std::endl;
|
||||||
<< std::endl;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (errno != ECONNREFUSED) {
|
if (errno != ECONNREFUSED) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
++attempt;
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!connected) {
|
if (!connected) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[Rank " << rank << "] connect() failed after " << attempt
|
msg << "[Rank " << rank << "] connect() failed after " << timeout_ms
|
||||||
<< " retries: " << strerror(errno);
|
<< " milliseconds and " << attempt << " retries: " << strerror(errno);
|
||||||
close(sock);
|
close(sock);
|
||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
@@ -256,7 +269,6 @@ class NCCLGroup : public GroupImpl {
|
|||||||
|
|
||||||
~NCCLGroup() {
|
~NCCLGroup() {
|
||||||
ncclCommDestroy(comm_);
|
ncclCommDestroy(comm_);
|
||||||
ncclGroupEnd();
|
|
||||||
initialized_ = false;
|
initialized_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -165,6 +165,11 @@ inline bool enable_tf32() {
|
|||||||
return enable_tf32_;
|
return enable_tf32_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline int nccl_timeout(int default_value) {
|
||||||
|
static int nccl_timeout = get_var("MLX_NCCL_TIMEOUT", default_value);
|
||||||
|
return nccl_timeout;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace env
|
} // namespace env
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
Reference in New Issue
Block a user