Compare commits

..

1 Commits

Author SHA1 Message Date
Anastasiia Filippova
9efabb380c
Merge e6ae350999 into 76831ed83d 2025-06-20 08:46:25 +08:00

View File

@ -80,7 +80,7 @@ inline void bootstrap_unique_id(
int rank, int rank,
int size, int size,
const std::string& initMethod) { const std::string& initMethod) {
// Parse the init method to extract the host and port
if (initMethod.rfind("tcp://", 0) != 0) if (initMethod.rfind("tcp://", 0) != 0)
throw; throw;
auto hostport = initMethod.substr(6); auto hostport = initMethod.substr(6);
@ -89,10 +89,8 @@ inline void bootstrap_unique_id(
int port = std::stoi(hostport.substr(colon + 1)); int port = std::stoi(hostport.substr(colon + 1));
if (rank == 0) { if (rank == 0) {
// create a unique id on the rank 0
CHECK_NCCL(ncclGetUniqueId(&id)); CHECK_NCCL(ncclGetUniqueId(&id));
// create a socket to send the unique id to all other ranks
int sock = socket(AF_INET, SOCK_STREAM, 0); int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) { if (sock < 0) {
@ -107,9 +105,6 @@ inline void bootstrap_unique_id(
serv.sin_port = htons(port); serv.sin_port = htons(port);
int reuse = 1; int reuse = 1;
// Without this, if rank-0 crashes or restarts process quickly,
// the OS might refuse to let binding to the same port, so reuse
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) { if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
std::ostringstream msg; std::ostringstream msg;
msg << "[nccl] setsockopt() failed: " << strerror(errno); msg << "[nccl] setsockopt() failed: " << strerror(errno);
@ -249,7 +244,9 @@ class NCCLGroup : public GroupImpl {
int ndev; int ndev;
CHECK_CUDA(cudaGetDeviceCount(&ndev)); CHECK_CUDA(cudaGetDeviceCount(&ndev));
CHECK_CUDA(cudaSetDevice(rank_ % ndev)); CHECK_CUDA(cudaSetDevice(rank_ % ndev));
detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_); CHECK_CUDA(cudaStreamCreate(&stream_));
detail::bootstrapUniqueId(uniqueId_, rank_, size_, initMethod_);
CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_)); CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
initialized_ = true; initialized_ = true;
} }
@ -257,6 +254,7 @@ class NCCLGroup : public GroupImpl {
~NCCLGroup() { ~NCCLGroup() {
ncclCommDestroy(comm_); ncclCommDestroy(comm_);
ncclGroupEnd(); ncclGroupEnd();
cudaStreamDestroy(stream_);
initialized_ = false; initialized_ = false;
} }
@ -269,9 +267,13 @@ class NCCLGroup : public GroupImpl {
} }
void all_sum(const array& input, array& output, Stream stream) override { void all_sum(const array& input, array& output, Stream stream) override {
if (input.size() != output.size()) {
throw std::runtime_error(
"[nccl] Input and output arrays must have the same size.");
}
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type; using T = typename decltype(type_tag)::type;
detail::all_reduce_impl<T>(input, output, stream, dt, ncclSum); all_reduce_impl<T>(input, output, stream, dt, ncclSum);
}); });
} }
@ -280,45 +282,29 @@ class NCCLGroup : public GroupImpl {
} }
void all_gather(const array& input, array& output, Stream stream) override { void all_gather(const array& input, array& output, Stream stream) override {
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { if (input.size() != output.size() / size_) {
using T = typename decltype(type_tag)::type; throw std::runtime_error(
CHECK_NCCL(ncclAllGather( "[nccl] Input size must match output size divided by group size.");
input.data<T>(), }
output.data<T>(),
input.size(),
dt,
comm_,
cu::get_stream(stream).last_cuda_stream()));
});
} }
void send(const array& input, int dst, Stream stream) override { void send(const array& input, int dst, Stream stream) override {
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { if (input.size() == 0) {
using T = typename decltype(type_tag)::type; return; // Nothing to send
CHECK_NCCL(ncclSend( }
input.data<T>(),
input.size(),
dt,
dst,
comm_,
cu::get_stream(stream).last_cuda_stream()));
});
} }
void recv(array& output, int src, Stream stream) override { void recv(array& output, int src, Stream stream) override {
detail::dispatch_dtype(output, [&](auto type_tag, ncclDataType_t dt) { if (output.size() == 0) {
using T = typename decltype(type_tag)::type; return; // Nothing to receive
CHECK_NCCL(ncclRecv( }
output.data<T>(),
output.size(),
dt,
src,
comm_,
cu::get_stream(stream).last_cuda_stream()));
});
} }
void all_max(const array& input, array& output, Stream stream) override { void all_max(const array& input, array& output, Stream stream) override {
if (input.size() != output.size()) {
throw std::runtime_error(
"[nccl] Input and output arrays must have the same size.");
}
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type; using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclMax); all_reduce_impl<T>(input, output, stream, dt, ncclMax);
@ -326,6 +312,10 @@ class NCCLGroup : public GroupImpl {
} }
void all_min(const array& input, array& output, Stream stream) override { void all_min(const array& input, array& output, Stream stream) override {
if (input.size() != output.size()) {
throw std::runtime_error(
"[nccl] Input and output arrays must have the same size.");
}
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type; using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclMin); all_reduce_impl<T>(input, output, stream, dt, ncclMin);
@ -339,6 +329,7 @@ class NCCLGroup : public GroupImpl {
Stream stream, Stream stream,
ncclDataType_t dt, ncclDataType_t dt,
ncclRedOp_t op) { ncclRedOp_t op) {
CHECK_NCCL(ncclAllReduce( CHECK_NCCL(ncclAllReduce(
input.data<T>(), input.data<T>(),
output.data<T>(), output.data<T>(),
@ -346,13 +337,15 @@ class NCCLGroup : public GroupImpl {
dt, dt,
op, op,
comm_, comm_,
cu::get_stream(stream).last_cuda_stream())); stream_));
cudaStreamSynchronize(stream_);
} }
int rank_, size_; int rank_, size_;
std::string initMethod_; std::string initMethod_;
ncclUniqueId uniqueId_; ncclUniqueId uniqueId_;
ncclComm_t comm_; ncclComm_t comm_;
cudaStream_t stream_;
bool initialized_ = false; bool initialized_ = false;
}; };