Use last cuda stream instead of new one

Anastasiia Filippova 2025-06-20 16:07:41 +02:00
parent e6ae350999
commit 043c37cccd


@@ -80,7 +80,7 @@ inline void bootstrap_unique_id(
int rank,
int size,
const std::string& initMethod) {
// Parse the init method to extract the host and port
if (initMethod.rfind("tcp://", 0) != 0)
throw std::runtime_error("[nccl] init method must start with tcp://");
auto hostport = initMethod.substr(6);
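For context, the init method is parsed by hand from a "tcp://host:port" string. A standalone sketch of the same parsing logic, with an illustrative helper name (parse_tcp_addr is not in the repo):

#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical helper mirroring the parsing above: splits
// "tcp://host:port" into its host and port components.
std::pair<std::string, int> parse_tcp_addr(const std::string& initMethod) {
  if (initMethod.rfind("tcp://", 0) != 0) {
    throw std::runtime_error("[nccl] init method must start with tcp://");
  }
  auto hostport = initMethod.substr(6); // drop the "tcp://" prefix
  auto colon = hostport.find(':');
  if (colon == std::string::npos) {
    throw std::runtime_error("[nccl] expected tcp://<host>:<port>");
  }
  return {hostport.substr(0, colon), std::stoi(hostport.substr(colon + 1))};
}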
@@ -89,8 +89,10 @@ inline void bootstrap_unique_id(
int port = std::stoi(hostport.substr(colon + 1));
if (rank == 0) {
// Create the unique id on rank 0
CHECK_NCCL(ncclGetUniqueId(&id));
// Create a socket for sending the unique id to all other ranks
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
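Rank 0 then acts as a one-shot rendezvous server: every other rank connects and receives the raw bytes of the id. A condensed sketch of that serving loop, assuming a socket that is already bound and listening (serve_unique_id and listen_fd are illustrative names, and real code should also handle partial sends and report errors):

#include <sys/socket.h>
#include <unistd.h>
#include <nccl.h>

// Sketch: rank 0 hands the unique id to the other size-1 ranks.
// ncclUniqueId is a fixed-size POD, so it can be shipped as raw bytes.
void serve_unique_id(int listen_fd, const ncclUniqueId& id, int size) {
  for (int peer = 1; peer < size; ++peer) {
    int conn = accept(listen_fd, nullptr, nullptr);
    if (conn < 0) continue; // real code should surface this error
    send(conn, &id, sizeof(id), 0);
    close(conn);
  }
}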
@@ -105,6 +107,9 @@ inline void bootstrap_unique_id(
serv.sin_port = htons(port);
int reuse = 1;
// Without this, if rank 0 crashes and restarts quickly, the OS may
// refuse to bind the same port again, so allow address reuse
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
std::ostringstream msg;
msg << "[nccl] setsockopt() failed: " << strerror(errno);
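Note that SO_REUSEADDR only takes effect if it is set before bind(). A minimal end-to-end sketch of the listening-socket setup, with error checking elided for brevity:

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>

// Sketch: create, mark reusable, then bind and listen. The option must
// be applied before bind() for it to have any effect.
int make_listen_socket(int port) {
  int sock = socket(AF_INET, SOCK_STREAM, 0);
  int reuse = 1;
  setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));
  sockaddr_in serv{};
  serv.sin_family = AF_INET;
  serv.sin_addr.s_addr = htonl(INADDR_ANY);
  serv.sin_port = htons(port);
  bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv));
  listen(sock, SOMAXCONN);
  return sock; // real code should check every call's return value
}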
@@ -244,9 +249,7 @@ class NCCLGroup : public GroupImpl {
int ndev;
CHECK_CUDA(cudaGetDeviceCount(&ndev));
CHECK_CUDA(cudaSetDevice(rank_ % ndev));
- CHECK_CUDA(cudaStreamCreate(&stream_));
- detail::bootstrapUniqueId(uniqueId_, rank_, size_, initMethod_);
+ detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
initialized_ = true;
}
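With the dedicated stream gone, construction reduces to binding the rank to a device round-robin and creating the communicator. A self-contained sketch of the same sequence, with the repo's CHECK_CUDA/CHECK_NCCL error macros elided:

#include <cuda_runtime.h>
#include <nccl.h>

// Sketch of the constructor's initialization path: bind this rank to a
// device round-robin, then create the NCCL communicator. No dedicated
// cudaStream_t is created; collectives run on the stream each op carries.
void init_comm(ncclComm_t* comm, const ncclUniqueId& id, int rank, int size) {
  int ndev = 0;
  cudaGetDeviceCount(&ndev);
  cudaSetDevice(rank % ndev);
  ncclCommInitRank(comm, size, id, rank);
}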
@@ -254,7 +257,6 @@ class NCCLGroup : public GroupImpl {
~NCCLGroup() {
ncclCommDestroy(comm_);
ncclGroupEnd();
- cudaStreamDestroy(stream_);
initialized_ = false;
}
@@ -267,13 +269,9 @@ class NCCLGroup : public GroupImpl {
}
void all_sum(const array& input, array& output, Stream stream) override {
if (input.size() != output.size()) {
throw std::runtime_error(
"[nccl] Input and output arrays must have the same size.");
}
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
- all_reduce_impl<T>(input, output, stream, dt, ncclSum);
+ detail::all_reduce_impl<T>(input, output, stream, dt, ncclSum);
});
}
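detail::dispatch_dtype pairs a compile-time type tag with the matching ncclDataType_t, so each collective body is written once. A simplified sketch of the pattern with an illustrative two-case dtype enum (the real helper presumably covers the full mlx dtype set):

#include <nccl.h>
#include <stdexcept>

// Simplified type tag, mirroring the `decltype(type_tag)::type` idiom above.
template <typename T>
struct TypeTag {
  using type = T;
};

enum class DtypeSketch { float32, int32 }; // stand-in for mlx's Dtype

// Sketch: map a runtime dtype to a compile-time type tag plus the
// corresponding NCCL datatype, then invoke the callback exactly once.
template <typename F>
void dispatch_dtype_sketch(DtypeSketch dtype, F&& f) {
  switch (dtype) {
    case DtypeSketch::float32:
      f(TypeTag<float>{}, ncclFloat);
      break;
    case DtypeSketch::int32:
      f(TypeTag<int>{}, ncclInt);
      break;
    default:
      throw std::runtime_error("[nccl] unsupported dtype");
  }
}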
@@ -282,29 +280,45 @@ class NCCLGroup : public GroupImpl {
}
void all_gather(const array& input, array& output, Stream stream) override {
if (input.size() != output.size() / size_) {
throw std::runtime_error(
"[nccl] Input size must match output size divided by group size.");
}
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
CHECK_NCCL(ncclAllGather(
input.data<T>(),
output.data<T>(),
input.size(),
dt,
comm_,
cu::get_stream(stream).last_cuda_stream()));
});
}
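The division by size_ in the check above reflects ncclAllGather's contract: every rank contributes input.size() elements and every rank receives the rank-ordered concatenation of all contributions. A minimal call sketch (all_gather_sketch is an illustrative name):

#include <cuda_runtime.h>
#include <nccl.h>

// Sketch: each rank contributes `n` elements; the result on every rank
// is the size * n concatenation, ordered by rank index.
void all_gather_sketch(
    ncclComm_t comm, cudaStream_t s, const float* in, float* out, size_t n) {
  ncclAllGather(in, out, n, ncclFloat, comm, s);
}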
void send(const array& input, int dst, Stream stream) override {
if (input.size() == 0) {
return; // Nothing to send
}
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
CHECK_NCCL(ncclSend(
input.data<T>(),
input.size(),
dt,
dst,
comm_,
cu::get_stream(stream).last_cuda_stream()));
});
}
void recv(array& output, int src, Stream stream) override {
if (output.size() == 0) {
return; // Nothing to receive
}
detail::dispatch_dtype(output, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
CHECK_NCCL(ncclRecv(
output.data<T>(),
output.size(),
dt,
src,
comm_,
cu::get_stream(stream).last_cuda_stream()));
});
}
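One caveat with these point-to-point calls: a rank that must both send and receive in the same step should fuse the two inside ncclGroupStart/ncclGroupEnd, otherwise two ranks that each post their send first can deadlock. A sketch of that pattern (names are illustrative, not from the repo):

#include <cuda_runtime.h>
#include <nccl.h>

// Sketch: a bidirectional exchange. Grouping the send and recv lets
// NCCL progress both together; issuing them back-to-back without a
// group can deadlock when both ranks send first.
void exchange(
    ncclComm_t comm,
    cudaStream_t s,
    const float* send_buf,
    float* recv_buf,
    size_t n,
    int peer) {
  ncclGroupStart();
  ncclSend(send_buf, n, ncclFloat, peer, comm, s);
  ncclRecv(recv_buf, n, ncclFloat, peer, comm, s);
  ncclGroupEnd();
}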
void all_max(const array& input, array& output, Stream stream) override {
if (input.size() != output.size()) {
throw std::runtime_error(
"[nccl] Input and output arrays must have the same size.");
}
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclMax);
@@ -312,10 +326,6 @@ class NCCLGroup : public GroupImpl {
}
void all_min(const array& input, array& output, Stream stream) override {
if (input.size() != output.size()) {
throw std::runtime_error(
"[nccl] Input and output arrays must have the same size.");
}
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclMin);
@@ -329,7 +339,6 @@ class NCCLGroup : public GroupImpl {
Stream stream,
ncclDataType_t dt,
ncclRedOp_t op) {
CHECK_NCCL(ncclAllReduce(
input.data<T>(),
output.data<T>(),
@@ -337,15 +346,13 @@ class NCCLGroup : public GroupImpl {
dt,
op,
comm_,
- stream_));
- cudaStreamSynchronize(stream_);
+ cu::get_stream(stream).last_cuda_stream()));
}
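This hunk is the heart of the commit: the collective is enqueued on the CUDA stream the operation already runs on (obtained via the repo's last_cuda_stream() accessor), so stream ordering alone guarantees that later work on that stream sees the reduced values, and the old host-blocking cudaStreamSynchronize can go. A small sketch of that ordering argument, assuming a stream s handed in by the framework:

#include <cuda_runtime.h>
#include <nccl.h>

// Sketch: work queued on one stream executes in order. Enqueueing the
// all-reduce on the op's own stream means later kernels on that stream
// only start after the reduction finishes, with no host-side sync.
void reduce_then_consume(
    ncclComm_t comm, cudaStream_t s, float* buf, size_t n) {
  ncclAllReduce(buf, buf, n, ncclFloat, ncclSum, comm, s);
  // Any kernel launched on `s` after this point sees the reduced values:
  // consumer_kernel<<<grid, block, 0, s>>>(buf);  // hypothetical kernel
}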
int rank_, size_;
std::string initMethod_;
ncclUniqueId uniqueId_;
ncclComm_t comm_;
- cudaStream_t stream_;
bool initialized_ = false;
};