mlx — mirror of https://github.com/ml-explore/mlx.git

commit 043c37cccd (parent e6ae350999)
Use last cuda stream instead of new one
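The substance of the change: NCCLGroup used to create its own CUDA stream in the constructor, launch every collective on it, and block with cudaStreamSynchronize afterwards. After this commit, collectives are enqueued on the CUDA stream that already backs the op's MLX Stream (cu::get_stream(stream).last_cuda_stream() in the hunks below), so ordering with neighboring kernels comes from stream order rather than a host-side sync. A minimal sketch of the before/after pattern in plain NCCL; the function names and buffers are illustrative, not MLX code:

```cpp
#include <cstddef>
#include <cuda_runtime.h>
#include <nccl.h>

// Old pattern: a private stream plus a blocking host synchronize.
// "comm" is an initialized communicator; in/out are device buffers of n floats.
void all_sum_old(const float* in, float* out, size_t n, ncclComm_t comm) {
  cudaStream_t s;
  cudaStreamCreate(&s);
  ncclAllReduce(in, out, n, ncclFloat, ncclSum, comm, s);
  cudaStreamSynchronize(s); // CPU stalls until the collective finishes
  cudaStreamDestroy(s);
}

// New pattern: enqueue on the stream the op already runs on; no host sync,
// later kernels on op_stream see the reduced values automatically.
void all_sum_new(const float* in, float* out, size_t n, ncclComm_t comm,
                 cudaStream_t op_stream) {
  ncclAllReduce(in, out, n, ncclFloat, ncclSum, comm, op_stream);
}
```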
@@ -80,7 +80,7 @@ inline void bootstrap_unique_id(
     int rank,
     int size,
     const std::string& initMethod) {

   // Parse the init method to extract the host and port
   if (initMethod.rfind("tcp://", 0) != 0)
     throw;
   auto hostport = initMethod.substr(6);
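For context, the init method is an address string of the form "tcp://<host>:<port>". Note the bare `throw;` above: outside a catch block it simply terminates the process. A self-contained sketch of the same parse with an explicit exception instead (function name and messages are hypothetical, not from the repo):

```cpp
#include <stdexcept>
#include <string>
#include <utility>

// Split "tcp://<host>:<port>" into host and numeric port.
std::pair<std::string, int> parse_init_method(const std::string& initMethod) {
  if (initMethod.rfind("tcp://", 0) != 0)
    throw std::invalid_argument("[nccl] init method must start with tcp://");
  auto hostport = initMethod.substr(6); // drop the "tcp://" prefix
  auto colon = hostport.find(':');
  if (colon == std::string::npos)
    throw std::invalid_argument("[nccl] expected tcp://<host>:<port>");
  return {hostport.substr(0, colon), std::stoi(hostport.substr(colon + 1))};
}
```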
@@ -89,8 +89,10 @@ inline void bootstrap_unique_id(
   int port = std::stoi(hostport.substr(colon + 1));

   if (rank == 0) {
     // create a unique id on the rank 0
     CHECK_NCCL(ncclGetUniqueId(&id));

     // create a socket to send the unique id to all other ranks
     int sock = socket(AF_INET, SOCK_STREAM, 0);

     if (sock < 0) {
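This hunk is the rank-0 side of the bootstrap: generate the ncclUniqueId and serve its raw bytes over a TCP socket. The consumer side is not shown in the diff; roughly, every other rank connects to rank 0 and reads sizeof(ncclUniqueId) bytes. A sketch under those assumptions (IPv4 address string, most error handling omitted):

```cpp
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <nccl.h>

// Connect to rank 0 and read the unique id byte-for-byte; TCP may deliver
// it in pieces, hence the loop.
ncclUniqueId recv_unique_id(const char* host_ip, int port) {
  int sock = socket(AF_INET, SOCK_STREAM, 0);
  sockaddr_in serv{};
  serv.sin_family = AF_INET;
  serv.sin_port = htons(port);
  inet_pton(AF_INET, host_ip, &serv.sin_addr);
  connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv));

  ncclUniqueId id;
  size_t got = 0;
  while (got < sizeof(id)) {
    ssize_t r =
        read(sock, reinterpret_cast<char*>(&id) + got, sizeof(id) - got);
    if (r <= 0)
      break;
    got += static_cast<size_t>(r);
  }
  close(sock);
  return id;
}
```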
@@ -105,6 +107,9 @@ inline void bootstrap_unique_id(
     serv.sin_port = htons(port);

+    int reuse = 1;
+    // Without this, if rank 0 crashes or restarts quickly, the OS may refuse
+    // to bind the same port again, so allow address reuse.

     if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
       std::ostringstream msg;
       msg << "[nccl] setsockopt() failed: " << strerror(errno);
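The added SO_REUSEADDR matters because a restarted rank 0 otherwise races the kernel: the previous listening socket can linger in TIME_WAIT, and bind() on the same port fails with EADDRINUSE until it expires. The option must be set before bind(), as in the hunk. A minimal listener sketch (helper name hypothetical, error checks omitted):

```cpp
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>

// Create a TCP listener that tolerates quick restarts on the same port.
int make_listener(int port) {
  int sock = socket(AF_INET, SOCK_STREAM, 0);
  int reuse = 1;
  setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));
  sockaddr_in serv{};
  serv.sin_family = AF_INET;
  serv.sin_addr.s_addr = htonl(INADDR_ANY);
  serv.sin_port = htons(port);
  bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv));
  listen(sock, /*backlog=*/16);
  return sock;
}
```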
@@ -244,9 +249,7 @@ class NCCLGroup : public GroupImpl {
     int ndev;
     CHECK_CUDA(cudaGetDeviceCount(&ndev));
     CHECK_CUDA(cudaSetDevice(rank_ % ndev));
-    CHECK_CUDA(cudaStreamCreate(&stream_));
-
-    detail::bootstrapUniqueId(uniqueId_, rank_, size_, initMethod_);
+    detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
     CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
     initialized_ = true;
   }
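CHECK_CUDA and CHECK_NCCL are used throughout but never defined in this diff. A conventional shape for such macros, matching the "[nccl] ... failed" message style seen earlier; this is a guess at the pattern, not MLX's exact definition:

```cpp
#include <cuda_runtime.h>
#include <nccl.h>

#include <sstream>
#include <stdexcept>

// Evaluate the call once; throw with the stringified expression on failure.
#define CHECK_CUDA(cmd)                                                 \
  do {                                                                  \
    cudaError_t e = (cmd);                                              \
    if (e != cudaSuccess) {                                             \
      std::ostringstream msg;                                           \
      msg << "[cuda] " << #cmd << " failed: " << cudaGetErrorString(e); \
      throw std::runtime_error(msg.str());                              \
    }                                                                   \
  } while (0)

#define CHECK_NCCL(cmd)                                                 \
  do {                                                                  \
    ncclResult_t r = (cmd);                                             \
    if (r != ncclSuccess) {                                             \
      std::ostringstream msg;                                           \
      msg << "[nccl] " << #cmd << " failed: " << ncclGetErrorString(r); \
      throw std::runtime_error(msg.str());                              \
    }                                                                   \
  } while (0)
```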
@@ -254,7 +257,6 @@ class NCCLGroup : public GroupImpl {
   ~NCCLGroup() {
     ncclCommDestroy(comm_);
     ncclGroupEnd();
-    cudaStreamDestroy(stream_);
     initialized_ = false;
   }

@@ -267,13 +269,9 @@ class NCCLGroup : public GroupImpl {
   }

   void all_sum(const array& input, array& output, Stream stream) override {
     if (input.size() != output.size()) {
       throw std::runtime_error(
           "[nccl] Input and output arrays must have the same size.");
     }
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
-      all_reduce_impl<T>(input, output, stream, dt, ncclSum);
+      detail::all_reduce_impl<T>(input, output, stream, dt, ncclSum);
     });
   }

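The dispatch_dtype helper hands the lambda a type tag, so `typename decltype(type_tag)::type` recovers the element type for data<T>(), alongside the matching ncclDataType_t. A hypothetical reconstruction of that dispatch shape with a stand-in dtype enum (MLX's real helper covers its full dtype set):

```cpp
#include <nccl.h>

enum class DtypeSketch { float32, int32 }; // stand-in for MLX's Dtype

template <typename T>
struct type_identity {
  using type = T; // callers recover T as typename decltype(type_tag)::type
};

template <typename F>
void dispatch_dtype_sketch(DtypeSketch dt, F&& f) {
  switch (dt) {
    case DtypeSketch::float32:
      f(type_identity<float>{}, ncclFloat);
      break;
    case DtypeSketch::int32:
      f(type_identity<int>{}, ncclInt);
      break;
  }
}
```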
@@ -282,29 +280,45 @@ class NCCLGroup : public GroupImpl {
   }

   void all_gather(const array& input, array& output, Stream stream) override {
     if (input.size() != output.size() / size_) {
       throw std::runtime_error(
           "[nccl] Input size must match output size divided by group size.");
     }
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
       CHECK_NCCL(ncclAllGather(
           input.data<T>(),
           output.data<T>(),
           input.size(),
           dt,
           comm_,
           cu::get_stream(stream).last_cuda_stream()));
     });
   }

   void send(const array& input, int dst, Stream stream) override {
     if (input.size() == 0) {
       return; // Nothing to send
     }
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
       CHECK_NCCL(ncclSend(
           input.data<T>(),
           input.size(),
           dt,
           dst,
           comm_,
           cu::get_stream(stream).last_cuda_stream()));
     });
   }

   void recv(array& output, int src, Stream stream) override {
     if (output.size() == 0) {
       return; // Nothing to receive
     }
     detail::dispatch_dtype(output, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
       CHECK_NCCL(ncclRecv(
           output.data<T>(),
           output.size(),
           dt,
           src,
           comm_,
           cu::get_stream(stream).last_cuda_stream()));
     });
   }

   void all_max(const array& input, array& output, Stream stream) override {
     if (input.size() != output.size()) {
       throw std::runtime_error(
           "[nccl] Input and output arrays must have the same size.");
     }
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
       all_reduce_impl<T>(input, output, stream, dt, ncclMax);
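One property of ncclSend/ncclRecv worth noting next to these hunks: a send blocks until the peer posts the matching receive, so two ranks exchanging data in both directions must group the calls or risk deadlock. A small sketch using NCCL's group API (function name illustrative):

```cpp
#include <cstddef>
#include <cuda_runtime.h>
#include <nccl.h>

// Bidirectional exchange with a single peer; grouping lets NCCL schedule
// the send and the receive together instead of deadlocking on the send.
void exchange(const float* send_buf, float* recv_buf, size_t n, int peer,
              ncclComm_t comm, cudaStream_t stream) {
  ncclGroupStart();
  ncclSend(send_buf, n, ncclFloat, peer, comm, stream);
  ncclRecv(recv_buf, n, ncclFloat, peer, comm, stream);
  ncclGroupEnd();
}
```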
@@ -312,10 +326,6 @@ class NCCLGroup : public GroupImpl {
   }

   void all_min(const array& input, array& output, Stream stream) override {
     if (input.size() != output.size()) {
       throw std::runtime_error(
           "[nccl] Input and output arrays must have the same size.");
     }
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
       all_reduce_impl<T>(input, output, stream, dt, ncclMin);
@@ -329,7 +339,6 @@ class NCCLGroup : public GroupImpl {
       Stream stream,
       ncclDataType_t dt,
       ncclRedOp_t op) {
-
     CHECK_NCCL(ncclAllReduce(
         input.data<T>(),
         output.data<T>(),
@@ -337,15 +346,13 @@ class NCCLGroup : public GroupImpl {
         dt,
         op,
         comm_,
-        stream_));
-    cudaStreamSynchronize(stream_);
+        cu::get_stream(stream).last_cuda_stream()));
   }

   int rank_, size_;
   std::string initMethod_;
   ncclUniqueId uniqueId_;
   ncclComm_t comm_;
-  cudaStream_t stream_;
   bool initialized_ = false;
 };
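With the stream_ member gone, nothing in the group blocks the host anymore; completion is observed through ordering on the op's stream. If host-side code ever does need to wait for a collective, the usual tool is an event recorded on that stream; a hypothetical helper, not part of this diff:

```cpp
#include <cuda_runtime.h>

// Block the host only until work queued on op_stream so far has finished.
void wait_for_collective(cudaStream_t op_stream) {
  cudaEvent_t done;
  cudaEventCreateWithFlags(&done, cudaEventDisableTiming);
  cudaEventRecord(done, op_stream);
  cudaEventSynchronize(done);
  cudaEventDestroy(done);
}
```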