mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
2 Commits
36ca62dba8
...
v0.29.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4bce5f9b2d | ||
|
|
e9eab527eb |
@@ -170,6 +170,10 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
|
|||||||
# Suppress nvcc warnings on MLX headers.
|
# Suppress nvcc warnings on MLX headers.
|
||||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||||
--diag_suppress=997>)
|
--diag_suppress=997>)
|
||||||
|
# Suppress warnings: note: parameter passing for argument of type
|
||||||
|
# ‘std::pair<float, float>’ when C++17 is enabled changed to match C++14 in GCC
|
||||||
|
# 10.1
|
||||||
|
target_compile_options(mlx PRIVATE -Wno-psabi)
|
||||||
|
|
||||||
# Install CCCL headers for JIT.
|
# Install CCCL headers for JIT.
|
||||||
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
||||||
|
|||||||
@@ -21,6 +21,9 @@
|
|||||||
|
|
||||||
namespace mlx::core::distributed::nccl {
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
|
// Can be tuned with MLX_NCCL_TIMEOUT
|
||||||
|
constexpr int nccl_timeout = 300000; // milliseconds
|
||||||
|
|
||||||
#define CHECK_CUDA(cmd) \
|
#define CHECK_CUDA(cmd) \
|
||||||
do { \
|
do { \
|
||||||
cudaError_t e = cmd; \
|
cudaError_t e = cmd; \
|
||||||
@@ -181,8 +184,9 @@ inline void bootstrap_unique_id(
|
|||||||
close(sock);
|
close(sock);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// Here just wanted to make show that rank 0 has enough time to bind
|
// Here we want to make sure that rank 0 has enough time to bind
|
||||||
// so we will retry to connect until max attempts
|
// so we will retry to connect until elapsed time exceeds nccl_timeout
|
||||||
|
// this is particularly important for multinode setup
|
||||||
|
|
||||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
if (sock < 0) {
|
if (sock < 0) {
|
||||||
@@ -200,32 +204,41 @@ inline void bootstrap_unique_id(
|
|||||||
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
|
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
|
||||||
serv.sin_port = htons(port);
|
serv.sin_port = htons(port);
|
||||||
|
|
||||||
const int max_retries = 30;
|
const int timeout_ms = env::nccl_timeout(nccl_timeout);
|
||||||
int attempt = 0;
|
|
||||||
bool connected = false;
|
bool connected = false;
|
||||||
|
|
||||||
bool do_log = std::getenv("NCCL_DEBUG") == "INFO";
|
const char* dbg = std::getenv("NCCL_DEBUG");
|
||||||
for (attempt = 0; attempt < max_retries; ++attempt) {
|
bool do_log = (dbg && std::string(dbg) == "INFO");
|
||||||
|
|
||||||
|
auto start = std::chrono::steady_clock::now();
|
||||||
|
int attempt = 0;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||||
|
std::chrono::steady_clock::now() - start)
|
||||||
|
.count();
|
||||||
|
if (elapsed_ms > timeout_ms)
|
||||||
|
break;
|
||||||
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
||||||
0) {
|
0) {
|
||||||
connected = true;
|
connected = true;
|
||||||
if (do_log) {
|
if (do_log) {
|
||||||
std::cout << "[Rank " << rank
|
std::cout << "[Rank " << rank << "] Connected successfully after "
|
||||||
<< "] Connected successfully on attempt " << attempt + 1
|
<< elapsed_ms << " miliseconds" << std::endl;
|
||||||
<< std::endl;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (errno != ECONNREFUSED) {
|
if (errno != ECONNREFUSED) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
++attempt;
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!connected) {
|
if (!connected) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[Rank " << rank << "] connect() failed after " << attempt
|
msg << "[Rank " << rank << "] connect() failed after " << timeout_ms
|
||||||
<< " retries: " << strerror(errno);
|
<< " milliseconds and " << attempt << " retries: " << strerror(errno);
|
||||||
close(sock);
|
close(sock);
|
||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
@@ -256,7 +269,6 @@ class NCCLGroup : public GroupImpl {
|
|||||||
|
|
||||||
~NCCLGroup() {
|
~NCCLGroup() {
|
||||||
ncclCommDestroy(comm_);
|
ncclCommDestroy(comm_);
|
||||||
ncclGroupEnd();
|
|
||||||
initialized_ = false;
|
initialized_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -165,6 +165,11 @@ inline bool enable_tf32() {
|
|||||||
return enable_tf32_;
|
return enable_tf32_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline int nccl_timeout(int default_value) {
|
||||||
|
static int nccl_timeout = get_var("MLX_NCCL_TIMEOUT", default_value);
|
||||||
|
return nccl_timeout;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace env
|
} // namespace env
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
Reference in New Issue
Block a user