mirror of https://github.com/ml-explore/mlx.git

Use last cuda stream instead of new one

commit 043c37cccd, parent e6ae350999
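Previously each NCCLGroup created a dedicated cudaStream_t in its
constructor, ran every collective on it, and blocked with
cudaStreamSynchronize after each all-reduce. This commit deletes that
stream (its creation, destruction, and the stream_ member) and instead
enqueues NCCL operations on the CUDA stream that already backs the MLX
Stream, obtained via cu::get_stream(stream).last_cuda_stream(). It also
reworks all_gather, send, and recv to issue ncclAllGather, ncclSend, and
ncclRecv through dispatch_dtype on that stream, and drops the size checks
that previously guarded the collectives.

A minimal sketch of the resulting call pattern, not the MLX code itself
(all_sum_on_stream and compute_stream are hypothetical names; the stream
stands in for the one returned by
cu::get_stream(stream).last_cuda_stream()):

    #include <cstddef>
    #include <cuda_runtime.h>
    #include <nccl.h>

    // Enqueue an all-reduce on the caller's existing compute stream
    // instead of on a privately owned, separately synchronized one.
    ncclResult_t all_sum_on_stream(
        const float* send_buf,
        float* recv_buf,
        size_t count,
        ncclComm_t comm,
        cudaStream_t compute_stream) {
      return ncclAllReduce(
          send_buf, recv_buf, count, ncclFloat, ncclSum, comm, compute_stream);
    }

NCCL collectives are stream-ordered like ordinary kernel launches, so any
later work submitted to compute_stream sees the reduced result; that
ordering guarantee is what lets the commit delete both the per-group
stream and the blocking cudaStreamSynchronize.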
@@ -80,7 +80,7 @@ inline void bootstrap_unique_id(
     int rank,
     int size,
     const std::string& initMethod) {
-
+  // Parse the init method to extract the host and port
   if (initMethod.rfind("tcp://", 0) != 0)
     throw;
   auto hostport = initMethod.substr(6);
@@ -89,8 +89,10 @@ inline void bootstrap_unique_id(
   int port = std::stoi(hostport.substr(colon + 1));
 
   if (rank == 0) {
+    // create a unique id on the rank 0
     CHECK_NCCL(ncclGetUniqueId(&id));
 
+    // create a socket to send the unique id to all other ranks
     int sock = socket(AF_INET, SOCK_STREAM, 0);
 
     if (sock < 0) {
@@ -105,6 +107,9 @@ inline void bootstrap_unique_id(
     serv.sin_port = htons(port);
 
     int reuse = 1;
+    // Without this, if rank-0 crashes or restarts process quickly,
+    // the OS might refuse to let binding to the same port, so reuse
+
     if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
       std::ostringstream msg;
       msg << "[nccl] setsockopt() failed: " << strerror(errno);
@@ -244,9 +249,7 @@ class NCCLGroup : public GroupImpl {
     int ndev;
     CHECK_CUDA(cudaGetDeviceCount(&ndev));
     CHECK_CUDA(cudaSetDevice(rank_ % ndev));
-    CHECK_CUDA(cudaStreamCreate(&stream_));
-
-    detail::bootstrapUniqueId(uniqueId_, rank_, size_, initMethod_);
+    detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
     CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
     initialized_ = true;
   }
@@ -254,7 +257,6 @@ class NCCLGroup : public GroupImpl {
   ~NCCLGroup() {
     ncclCommDestroy(comm_);
     ncclGroupEnd();
-    cudaStreamDestroy(stream_);
     initialized_ = false;
   }
 
@@ -267,13 +269,9 @@ class NCCLGroup : public GroupImpl {
   }
 
   void all_sum(const array& input, array& output, Stream stream) override {
-    if (input.size() != output.size()) {
-      throw std::runtime_error(
-          "[nccl] Input and output arrays must have the same size.");
-    }
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
-      all_reduce_impl<T>(input, output, stream, dt, ncclSum);
+      detail::all_reduce_impl<T>(input, output, stream, dt, ncclSum);
     });
   }
 
@@ -282,29 +280,45 @@ class NCCLGroup : public GroupImpl {
   }
 
   void all_gather(const array& input, array& output, Stream stream) override {
-    if (input.size() != output.size() / size_) {
-      throw std::runtime_error(
-          "[nccl] Input size must match output size divided by group size.");
-    }
+    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      CHECK_NCCL(ncclAllGather(
+          input.data<T>(),
+          output.data<T>(),
+          input.size(),
+          dt,
+          comm_,
+          cu::get_stream(stream).last_cuda_stream()));
+    });
   }
 
   void send(const array& input, int dst, Stream stream) override {
-    if (input.size() == 0) {
-      return; // Nothing to send
-    }
+    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      CHECK_NCCL(ncclSend(
+          input.data<T>(),
+          input.size(),
+          dt,
+          dst,
+          comm_,
+          cu::get_stream(stream).last_cuda_stream()));
+    });
   }
 
   void recv(array& output, int src, Stream stream) override {
-    if (output.size() == 0) {
-      return; // Nothing to receive
-    }
+    detail::dispatch_dtype(output, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      CHECK_NCCL(ncclRecv(
+          output.data<T>(),
+          output.size(),
+          dt,
+          src,
+          comm_,
+          cu::get_stream(stream).last_cuda_stream()));
+    });
   }
 
   void all_max(const array& input, array& output, Stream stream) override {
-    if (input.size() != output.size()) {
-      throw std::runtime_error(
-          "[nccl] Input and output arrays must have the same size.");
-    }
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
       all_reduce_impl<T>(input, output, stream, dt, ncclMax);
@@ -312,10 +326,6 @@ class NCCLGroup : public GroupImpl {
   }
 
   void all_min(const array& input, array& output, Stream stream) override {
-    if (input.size() != output.size()) {
-      throw std::runtime_error(
-          "[nccl] Input and output arrays must have the same size.");
-    }
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
       all_reduce_impl<T>(input, output, stream, dt, ncclMin);
@@ -329,7 +339,6 @@ class NCCLGroup : public GroupImpl {
       Stream stream,
       ncclDataType_t dt,
       ncclRedOp_t op) {
-
     CHECK_NCCL(ncclAllReduce(
         input.data<T>(),
         output.data<T>(),
@@ -337,15 +346,13 @@ class NCCLGroup : public GroupImpl {
         dt,
         op,
         comm_,
-        stream_));
-    cudaStreamSynchronize(stream_);
+        cu::get_stream(stream).last_cuda_stream()));
   }
 
   int rank_, size_;
   std::string initMethod_;
   ncclUniqueId uniqueId_;
   ncclComm_t comm_;
-  cudaStream_t stream_;
   bool initialized_ = false;
 };
 
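For context on the first three hunks: rank 0 creates the NCCL unique id
and serves it to the other ranks over a TCP socket parsed from the
tcp://host:port init method, and the new comments document the
SO_REUSEADDR option. Below is a self-contained sketch of that listener
setup, assuming POSIX sockets (listen_for_ranks is a hypothetical name;
the diff's actual helper is bootstrap_unique_id, whose error handling is
richer than shown here):

    #include <arpa/inet.h>
    #include <netinet/in.h>
    #include <sys/socket.h>

    #include <stdexcept>
    #include <string>

    int listen_for_ranks(const std::string& init_method) {
      // The init method looks like "tcp://host:port".
      if (init_method.rfind("tcp://", 0) != 0)
        throw std::invalid_argument("[nccl] expected tcp://host:port");
      auto hostport = init_method.substr(6);
      auto colon = hostport.find(':');
      int port = std::stoi(hostport.substr(colon + 1));

      int sock = socket(AF_INET, SOCK_STREAM, 0);
      if (sock < 0)
        throw std::runtime_error("[nccl] socket() failed");

      // SO_REUSEADDR lets a restarted rank 0 rebind the port while the
      // previous socket lingers in TIME_WAIT, the situation the added
      // comment in the diff describes.
      int reuse = 1;
      if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0)
        throw std::runtime_error("[nccl] setsockopt() failed");

      sockaddr_in serv{};
      serv.sin_family = AF_INET;
      serv.sin_addr.s_addr = htonl(INADDR_ANY);
      serv.sin_port = htons(port);
      if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0)
        throw std::runtime_error("[nccl] bind() failed");
      if (listen(sock, SOMAXCONN) < 0)
        throw std::runtime_error("[nccl] listen() failed");
      return sock;
    }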