Compare commits

...

2 Commits

Author SHA1 Message Date
Awni Hannun
4bce5f9b2d suppress gcc 10.1 warnings (#2679)
* suppress gcc 10.1 warnings

* suppress gcc 10.1 warnings
2025-10-17 12:09:21 -07:00
Anastasiia Filippova
e9eab527eb Nccl timeout (#2673)
* print the error & delete nccl group

* timeout for nccl binding

* typo

* revert error

* fixed a typo
2025-10-14 12:29:54 -07:00
3 changed files with 33 additions and 12 deletions

View File

@@ -170,6 +170,10 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)
# Suppress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda

View File

@@ -21,6 +21,9 @@
namespace mlx::core::distributed::nccl {
// Default time budget, in milliseconds, passed to env::nccl_timeout() as the
// fallback value (300000 ms = 5 minutes).
// Can be tuned with MLX_NCCL_TIMEOUT
constexpr int nccl_timeout = 300000; // milliseconds
#define CHECK_CUDA(cmd) \
do { \
cudaError_t e = cmd; \
@@ -181,8 +184,9 @@ inline void bootstrap_unique_id(
close(sock);
} else {
// Here just wanted to make show that rank 0 has enough time to bind
// so we will retry to connect until max attempts
// Here we want to make sure that rank 0 has enough time to bind
// so we will retry to connect until elapsed time exceeds nccl_timeout
// this is particularly important for multinode setup
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
@@ -200,32 +204,41 @@ inline void bootstrap_unique_id(
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
serv.sin_port = htons(port);
const int max_retries = 30;
int attempt = 0;
const int timeout_ms = env::nccl_timeout(nccl_timeout);
bool connected = false;
bool do_log = std::getenv("NCCL_DEBUG") == "INFO";
for (attempt = 0; attempt < max_retries; ++attempt) {
const char* dbg = std::getenv("NCCL_DEBUG");
bool do_log = (dbg && std::string(dbg) == "INFO");
auto start = std::chrono::steady_clock::now();
int attempt = 0;
while (true) {
auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - start)
.count();
if (elapsed_ms > timeout_ms)
break;
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
0) {
connected = true;
if (do_log) {
std::cout << "[Rank " << rank
<< "] Connected successfully on attempt " << attempt + 1
<< std::endl;
std::cout << "[Rank " << rank << "] Connected successfully after "
<< elapsed_ms << " miliseconds" << std::endl;
break;
}
}
if (errno != ECONNREFUSED) {
break;
}
++attempt;
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}
if (!connected) {
std::ostringstream msg;
msg << "[Rank " << rank << "] connect() failed after " << attempt
<< " retries: " << strerror(errno);
msg << "[Rank " << rank << "] connect() failed after " << timeout_ms
<< " milliseconds and " << attempt << " retries: " << strerror(errno);
close(sock);
throw std::runtime_error(msg.str());
}
@@ -256,7 +269,6 @@ class NCCLGroup : public GroupImpl {
// Destructor: tears down the NCCL state held by this group.
// The return codes of the NCCL calls below are not checked.
~NCCLGroup() {
// Destroy the communicator stored in comm_.
ncclCommDestroy(comm_);
// NOTE(review): no matching ncclGroupStart() is visible in this hunk —
// confirm that this call is intended here.
ncclGroupEnd();
// Mark the group as no longer initialized.
initialized_ = false;
}

View File

@@ -165,6 +165,11 @@ inline bool enable_tf32() {
return enable_tf32_;
}
// Returns the NCCL timeout setting.
//
// The value is obtained from get_var("MLX_NCCL_TIMEOUT", default_value) the
// first time this function is called and kept in a function-local static, so
// every later call returns that same value and its default_value argument is
// not consulted again.
// NOTE(review): presumably get_var reads the environment variable and falls
// back to default_value when it is unset — confirm against its definition.
inline int nccl_timeout(int default_value) {
  static int cached_timeout = get_var("MLX_NCCL_TIMEOUT", default_value);
  return cached_timeout;
}
} // namespace env
} // namespace mlx::core