From 043c37cccd88d5043a4a68634e56dd496e93640c Mon Sep 17 00:00:00 2001
From: Anastasiia Filippova
Date: Fri, 20 Jun 2025 16:07:41 +0200
Subject: [PATCH] Use last cuda stream instead of new one

---
 mlx/distributed/nccl/nccl.cpp | 71 +++++++++++++++++++----------------
 1 file changed, 39 insertions(+), 32 deletions(-)

diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp
index f6fa28ad8..02b1fc20c 100644
--- a/mlx/distributed/nccl/nccl.cpp
+++ b/mlx/distributed/nccl/nccl.cpp
@@ -80,7 +80,7 @@ inline void bootstrap_unique_id(
     int rank,
     int size,
     const std::string& initMethod) {
-
+  // Parse the init method to extract the host and port
   if (initMethod.rfind("tcp://", 0) != 0)
     throw;
   auto hostport = initMethod.substr(6);
@@ -89,8 +89,10 @@ inline void bootstrap_unique_id(
   int port = std::stoi(hostport.substr(colon + 1));

   if (rank == 0) {
+    // Create the unique id on rank 0
     CHECK_NCCL(ncclGetUniqueId(&id));

+    // Create a socket to send the unique id to all other ranks
     int sock = socket(AF_INET, SOCK_STREAM, 0);
     if (sock < 0) {
@@ -105,6 +107,9 @@ inline void bootstrap_unique_id(
     serv.sin_port = htons(port);

     int reuse = 1;
+    // Without this, if rank 0 crashes or restarts quickly, the OS may
+    // refuse to bind to the same port again, so allow address reuse
+
     if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
       std::ostringstream msg;
       msg << "[nccl] setsockopt() failed: " << strerror(errno);
@@ -244,9 +249,7 @@ class NCCLGroup : public GroupImpl {
     int ndev;
     CHECK_CUDA(cudaGetDeviceCount(&ndev));
     CHECK_CUDA(cudaSetDevice(rank_ % ndev));
-    CHECK_CUDA(cudaStreamCreate(&stream_));
-
-    detail::bootstrapUniqueId(uniqueId_, rank_, size_, initMethod_);
+    detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
     CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
     initialized_ = true;
   }
@@ -254,7 +257,6 @@ class NCCLGroup : public GroupImpl {
   ~NCCLGroup() {
     ncclCommDestroy(comm_);
     ncclGroupEnd();
-    cudaStreamDestroy(stream_);
     initialized_ = false;
   }

@@ -267,13 +269,9 @@ class NCCLGroup : public GroupImpl {
   }

   void all_sum(const array& input, array& output, Stream stream) override {
-    if (input.size() != output.size()) {
-      throw std::runtime_error(
-          "[nccl] Input and output arrays must have the same size.");
-    }
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
-      all_reduce_impl<T>(input, output, stream, dt, ncclSum);
+      detail::all_reduce_impl<T>(input, output, stream, dt, ncclSum);
     });
   }

@@ -282,29 +280,45 @@ class NCCLGroup : public GroupImpl {
   }

   void all_gather(const array& input, array& output, Stream stream) override {
-    if (input.size() != output.size() / size_) {
-      throw std::runtime_error(
-          "[nccl] Input size must match output size divided by group size.");
-    }
+    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      CHECK_NCCL(ncclAllGather(
+          input.data<T>(),
+          output.data<T>(),
+          input.size(),
+          dt,
+          comm_,
+          cu::get_stream(stream).last_cuda_stream()));
+    });
   }

   void send(const array& input, int dst, Stream stream) override {
-    if (input.size() == 0) {
-      return; // Nothing to send
-    }
+    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      CHECK_NCCL(ncclSend(
+          input.data<T>(),
+          input.size(),
+          dt,
+          dst,
+          comm_,
+          cu::get_stream(stream).last_cuda_stream()));
+    });
   }

   void recv(array& output, int src, Stream stream) override {
-    if (output.size() == 0) {
-      return; // Nothing to receive
-    }
+    detail::dispatch_dtype(output, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      CHECK_NCCL(ncclRecv(
+          output.data<T>(),
+          output.size(),
+          dt,
+          src,
+          comm_,
+          cu::get_stream(stream).last_cuda_stream()));
+    });
   }

   void all_max(const array& input, array& output, Stream stream) override {
-    if (input.size() != output.size()) {
-      throw std::runtime_error(
-          "[nccl] Input and output arrays must have the same size.");
-    }
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
       all_reduce_impl<T>(input, output, stream, dt, ncclMax);
@@ -312,10 +326,6 @@ class NCCLGroup : public GroupImpl {
   }

   void all_min(const array& input, array& output, Stream stream) override {
-    if (input.size() != output.size()) {
-      throw std::runtime_error(
-          "[nccl] Input and output arrays must have the same size.");
-    }
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
       all_reduce_impl<T>(input, output, stream, dt, ncclMin);
@@ -329,7 +339,6 @@ class NCCLGroup : public GroupImpl {
       Stream stream,
       ncclDataType_t dt,
       ncclRedOp_t op) {
-
     CHECK_NCCL(ncclAllReduce(
         input.data<T>(),
         output.data<T>(),
@@ -337,15 +346,13 @@ class NCCLGroup : public GroupImpl {
         dt,
         op,
         comm_,
-        stream_));
-    cudaStreamSynchronize(stream_);
+        cu::get_stream(stream).last_cuda_stream()));
   }

   int rank_, size_;
   std::string initMethod_;
   ncclUniqueId uniqueId_;
   ncclComm_t comm_;
-  cudaStream_t stream_;
   bool initialized_ = false;
 };
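
Note on the change: instead of creating and destroying a private cudaStream_t per NCCLGroup and calling cudaStreamSynchronize after each all-reduce, the patch enqueues every NCCL operation on the CUDA stream MLX is already using for the given Stream, obtained via cu::get_stream(stream).last_cuda_stream(). A minimal C++ sketch of that pattern against only the public NCCL/CUDA APIs (the function name and float buffers below are illustrative, not part of the patch):

    #include <cstddef>
    #include <cuda_runtime.h>
    #include <nccl.h>

    // Enqueue an all-reduce on a stream the caller already uses for compute,
    // instead of creating a private stream and blocking on it. NCCL orders
    // the collective after work previously queued on `stream`, so no
    // cudaStreamSynchronize() is needed here; later kernels launched on the
    // same stream observe the reduced values.
    ncclResult_t all_sum_on_existing_stream(
        const float* send_buf,
        float* recv_buf,
        std::size_t count,
        ncclComm_t comm,
        cudaStream_t stream) {
      return ncclAllReduce(
          send_buf, recv_buf, count, ncclFloat, ncclSum, comm, stream);
    }

Because the collective runs on the same stream that produced its inputs and will consume its outputs, stream ordering replaces the explicit synchronize that the old code needed.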