Use the last CUDA stream instead of creating a new one

This commit is contained in:
Anastasiia Filippova 2025-06-20 16:07:41 +02:00
parent e6ae350999
commit 043c37cccd

View File

@ -80,7 +80,7 @@ inline void bootstrap_unique_id(
int rank, int rank,
int size, int size,
const std::string& initMethod) { const std::string& initMethod) {
// Parse the init method to extract the host and port
if (initMethod.rfind("tcp://", 0) != 0) if (initMethod.rfind("tcp://", 0) != 0)
throw; throw;
auto hostport = initMethod.substr(6); auto hostport = initMethod.substr(6);
@ -89,8 +89,10 @@ inline void bootstrap_unique_id(
int port = std::stoi(hostport.substr(colon + 1)); int port = std::stoi(hostport.substr(colon + 1));
if (rank == 0) { if (rank == 0) {
// create a unique id on the rank 0
CHECK_NCCL(ncclGetUniqueId(&id)); CHECK_NCCL(ncclGetUniqueId(&id));
// create a socket to send the unique id to all other ranks
int sock = socket(AF_INET, SOCK_STREAM, 0); int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) { if (sock < 0) {
@ -105,6 +107,9 @@ inline void bootstrap_unique_id(
serv.sin_port = htons(port); serv.sin_port = htons(port);
int reuse = 1; int reuse = 1;
// Without this, if rank-0 crashes or restarts process quickly,
// the OS might refuse to let binding to the same port, so reuse
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) { if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
std::ostringstream msg; std::ostringstream msg;
msg << "[nccl] setsockopt() failed: " << strerror(errno); msg << "[nccl] setsockopt() failed: " << strerror(errno);
@ -244,9 +249,7 @@ class NCCLGroup : public GroupImpl {
int ndev; int ndev;
CHECK_CUDA(cudaGetDeviceCount(&ndev)); CHECK_CUDA(cudaGetDeviceCount(&ndev));
CHECK_CUDA(cudaSetDevice(rank_ % ndev)); CHECK_CUDA(cudaSetDevice(rank_ % ndev));
CHECK_CUDA(cudaStreamCreate(&stream_)); detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
detail::bootstrapUniqueId(uniqueId_, rank_, size_, initMethod_);
CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_)); CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
initialized_ = true; initialized_ = true;
} }
@ -254,7 +257,6 @@ class NCCLGroup : public GroupImpl {
~NCCLGroup() { ~NCCLGroup() {
ncclCommDestroy(comm_); ncclCommDestroy(comm_);
ncclGroupEnd(); ncclGroupEnd();
cudaStreamDestroy(stream_);
initialized_ = false; initialized_ = false;
} }
@ -267,13 +269,9 @@ class NCCLGroup : public GroupImpl {
} }
void all_sum(const array& input, array& output, Stream stream) override { void all_sum(const array& input, array& output, Stream stream) override {
if (input.size() != output.size()) {
throw std::runtime_error(
"[nccl] Input and output arrays must have the same size.");
}
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type; using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclSum); detail::all_reduce_impl<T>(input, output, stream, dt, ncclSum);
}); });
} }
@ -282,29 +280,45 @@ class NCCLGroup : public GroupImpl {
} }
void all_gather(const array& input, array& output, Stream stream) override { void all_gather(const array& input, array& output, Stream stream) override {
if (input.size() != output.size() / size_) { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
throw std::runtime_error( using T = typename decltype(type_tag)::type;
"[nccl] Input size must match output size divided by group size."); CHECK_NCCL(ncclAllGather(
} input.data<T>(),
output.data<T>(),
input.size(),
dt,
comm_,
cu::get_stream(stream).last_cuda_stream()));
});
} }
void send(const array& input, int dst, Stream stream) override { void send(const array& input, int dst, Stream stream) override {
if (input.size() == 0) { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
return; // Nothing to send using T = typename decltype(type_tag)::type;
} CHECK_NCCL(ncclSend(
input.data<T>(),
input.size(),
dt,
dst,
comm_,
cu::get_stream(stream).last_cuda_stream()));
});
} }
void recv(array& output, int src, Stream stream) override { void recv(array& output, int src, Stream stream) override {
if (output.size() == 0) { detail::dispatch_dtype(output, [&](auto type_tag, ncclDataType_t dt) {
return; // Nothing to receive using T = typename decltype(type_tag)::type;
} CHECK_NCCL(ncclRecv(
output.data<T>(),
output.size(),
dt,
src,
comm_,
cu::get_stream(stream).last_cuda_stream()));
});
} }
void all_max(const array& input, array& output, Stream stream) override { void all_max(const array& input, array& output, Stream stream) override {
if (input.size() != output.size()) {
throw std::runtime_error(
"[nccl] Input and output arrays must have the same size.");
}
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type; using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclMax); all_reduce_impl<T>(input, output, stream, dt, ncclMax);
@ -312,10 +326,6 @@ class NCCLGroup : public GroupImpl {
} }
void all_min(const array& input, array& output, Stream stream) override { void all_min(const array& input, array& output, Stream stream) override {
if (input.size() != output.size()) {
throw std::runtime_error(
"[nccl] Input and output arrays must have the same size.");
}
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type; using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclMin); all_reduce_impl<T>(input, output, stream, dt, ncclMin);
@ -329,7 +339,6 @@ class NCCLGroup : public GroupImpl {
Stream stream, Stream stream,
ncclDataType_t dt, ncclDataType_t dt,
ncclRedOp_t op) { ncclRedOp_t op) {
CHECK_NCCL(ncclAllReduce( CHECK_NCCL(ncclAllReduce(
input.data<T>(), input.data<T>(),
output.data<T>(), output.data<T>(),
@ -337,15 +346,13 @@ class NCCLGroup : public GroupImpl {
dt, dt,
op, op,
comm_, comm_,
stream_)); cu::get_stream(stream).last_cuda_stream()));
cudaStreamSynchronize(stream_);
} }
int rank_, size_; int rank_, size_;
std::string initMethod_; std::string initMethod_;
ncclUniqueId uniqueId_; ncclUniqueId uniqueId_;
ncclComm_t comm_; ncclComm_t comm_;
cudaStream_t stream_;
bool initialized_ = false; bool initialized_ = false;
}; };