From f15a127900a719646255c7bfe1c4a910e42c3f1d Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 16 Jun 2025 14:28:53 +0200 Subject: [PATCH 1/7] nccl backend (all reduce + init) --- cmake/FindNCCL.cmake | 64 +++++ mlx/backend/cuda/primitives.cu | 24 +- mlx/distributed/CMakeLists.txt | 1 + mlx/distributed/distributed.cpp | 4 + mlx/distributed/nccl/CMakeLists.txt | 8 + mlx/distributed/nccl/nccl.cpp | 360 ++++++++++++++++++++++++++++ mlx/distributed/nccl/nccl.h | 12 + mlx/distributed/nccl/no_nccl.cpp | 20 ++ mlx/distributed/ops.cpp | 3 +- 9 files changed, 493 insertions(+), 3 deletions(-) create mode 100644 cmake/FindNCCL.cmake create mode 100644 mlx/distributed/nccl/CMakeLists.txt create mode 100644 mlx/distributed/nccl/nccl.cpp create mode 100644 mlx/distributed/nccl/nccl.h create mode 100644 mlx/distributed/nccl/no_nccl.cpp diff --git a/cmake/FindNCCL.cmake b/cmake/FindNCCL.cmake new file mode 100644 index 000000000..7f8791476 --- /dev/null +++ b/cmake/FindNCCL.cmake @@ -0,0 +1,64 @@ +# Find the nccl libraries +# +# The following variables are optionally searched for defaults NCCL_ROOT_DIR: +# Base directory where all NCCL components are found NCCL_INCLUDE_DIR: Directory +# where NCCL header is found NCCL_LIB_DIR: Directory where NCCL library is found +# +# The following are set after configuration is done: NCCL_FOUND +# NCCL_INCLUDE_DIRS NCCL_LIBRARIES +# +# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks install NCCL +# in the same location as the CUDA toolkit. See +# https://github.com/caffe2/caffe2/issues/1601 + +set(NCCL_ROOT_DIR + $ENV{NCCL_ROOT_DIR} + CACHE PATH "Folder contains NVIDIA NCCL") + +find_path( + NCCL_INCLUDE_DIRS + NAMES nccl.h + HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include + ${CUDA_TOOLKIT_ROOT_DIR}/include) + +if($ENV{USE_STATIC_NCCL}) + message( + STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library") + set(NCCL_LIBNAME "libnccl_static.a") +else() + set(NCCL_LIBNAME "nccl") +endif() + +find_library( + NCCL_LIBRARIES + NAMES ${NCCL_LIBNAME} + HINTS ${NCCL_LIB_DIR} + ${NCCL_ROOT_DIR} + ${NCCL_ROOT_DIR}/lib + ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu + ${NCCL_ROOT_DIR}/lib64 + ${CUDA_TOOLKIT_ROOT_DIR}/lib + ${CUDA_TOOLKIT_ROOT_DIR}/lib64) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS + NCCL_LIBRARIES) + +if(NCCL_FOUND) + set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h") + message( + STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}") + file( + STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED + REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$" + LIMIT_COUNT 1) + if(NCCL_MAJOR_VERSION_DEFINED) + string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" "" + NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED}) + message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}") + endif() + message( + STATUS + "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") + mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) +endif() diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index c2362bea2..d143f9fa5 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -54,6 +54,28 @@ bool fast::ScaledDotProductAttention::use_fallback( return true; } +namespace distributed { +void AllReduce::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + // Here I assume for now that in is donatable and contiguous. 
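+  // (copy_shared_buffer below aliases the output to the input buffer, so no
+  // device copy is made before the collective runs.)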
+ // TODO + + auto& input = inputs[0]; + auto& output = outputs[0]; + + output.copy_shared_buffer(input); + auto& s = stream(); + switch (reduce_type_) { + case Sum: + distributed::detail::all_sum(group(), input, output, s); + break; + default: + throw std::runtime_error("Only all reduce sum is supported for now"); + } +} +} // namespace distributed + #define NO_GPU_MULTI(func) \ void func::eval_gpu( \ const std::vector& inputs, std::vector& outputs) { \ @@ -100,7 +122,7 @@ NO_GPU_MULTI(CustomKernel) } // namespace fast namespace distributed { -NO_GPU_MULTI(AllReduce) +// NO_GPU_MULTI(AllReduce) NO_GPU_MULTI(AllGather) NO_GPU_MULTI(Send) NO_GPU_MULTI(Recv) diff --git a/mlx/distributed/CMakeLists.txt b/mlx/distributed/CMakeLists.txt index 8e16bd40d..b7762f6a7 100644 --- a/mlx/distributed/CMakeLists.txt +++ b/mlx/distributed/CMakeLists.txt @@ -6,3 +6,4 @@ target_sources( add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl) diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp index cc01e6090..b299b1b9d 100644 --- a/mlx/distributed/distributed.cpp +++ b/mlx/distributed/distributed.cpp @@ -2,9 +2,11 @@ #include +#include #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/mpi/mpi.h" +#include "mlx/distributed/nccl/nccl.h" #include "mlx/distributed/ring/ring.h" namespace mlx::core::distributed { @@ -111,6 +113,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) { group = mpi::init(strict); } else if (bk == "ring") { group = ring::init(strict); + } else if (bk == "nccl") { + group = nccl::init(strict); } else if (bk == "any") { group = ring::init(false); bk_ = "ring"; diff --git a/mlx/distributed/nccl/CMakeLists.txt b/mlx/distributed/nccl/CMakeLists.txt new file mode 100644 index 000000000..2f764c6ac --- /dev/null +++ b/mlx/distributed/nccl/CMakeLists.txt @@ -0,0 +1,8 @@ +if(MLX_BUILD_CUDA) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp) + find_package(NCCL REQUIRED) + target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES}) + target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS}) +else() + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp) +endif() diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp new file mode 100644 index 000000000..7cb37a05a --- /dev/null +++ b/mlx/distributed/nccl/nccl.cpp @@ -0,0 +1,360 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mlx/backend/cuda/device.h" +#include "mlx/distributed/distributed.h" +#include "mlx/distributed/distributed_impl.h" + +namespace mlx::core::distributed::nccl { + +#define CHECK_CUDA(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + fprintf( \ + stderr, \ + "CUDA error %s:%d '%s'\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(e)); \ + exit(1); \ + } \ + } while (0) + +#define CHECK_NCCL(cmd) \ + do { \ + ncclResult_t r = cmd; \ + if (r != ncclSuccess) { \ + fprintf( \ + stderr, \ + "NCCL error %s:%d '%s'\n", \ + __FILE__, \ + __LINE__, \ + ncclGetErrorString(r)); \ + exit(1); \ + } \ + } while (0) + +namespace detail { + +inline void sendAll(int sock, const void* buf, size_t len) { + const char* ptr = reinterpret_cast(buf); + while (len > 0) { + ssize_t sent = send(sock, ptr, len, 0); + if (sent <= 0) { + perror("send"); + exit(1); + } + ptr += 
sent; + len -= sent; + } +} + +inline void recvAll(int sock, void* buf, size_t len) { + char* ptr = reinterpret_cast(buf); + while (len > 0) { + ssize_t rec = recv(sock, ptr, len, 0); + if (rec <= 0) { + perror("recv"); + exit(1); + } + ptr += rec; + len -= rec; + } +} + +inline void bootstrapUniqueId( + ncclUniqueId& id, + int rank, + int size, + const std::string& initMethod) { + // Parse the init method to extract the host and port + if (initMethod.rfind("tcp://", 0) != 0) + throw; + auto hostport = initMethod.substr(6); + auto colon = hostport.find(':'); + std::string host = hostport.substr(0, colon); + int port = std::stoi(hostport.substr(colon + 1)); + + if (rank == 0) { + // create a unique id on the rank 0 + CHECK_NCCL(ncclGetUniqueId(&id)); + + // create a socket to send the unique id to all other ranks + int sock = socket(AF_INET, SOCK_STREAM, 0); + + if (sock < 0) { + std::ostringstream msg; + msg << "[nccl] Couldn't create socket (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + sockaddr_in serv = {}; + serv.sin_family = AF_INET; + serv.sin_addr.s_addr = htonl(INADDR_ANY); + serv.sin_port = htons(port); + + int reuse = 1; + // Without this, if I crash or restart your rank-0 process quickly, + // the OS might refuse to let you bind to the same port, so reuse + if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) { + std::ostringstream msg; + msg << "[nccl] setsockopt() failed: " << strerror(errno); + throw std::runtime_error(msg.str()); + } + + if (bind(sock, reinterpret_cast(&serv), sizeof(serv)) < 0) { + std::ostringstream msg; + msg << "[nccl] bind() failed: " << strerror(errno); + throw std::runtime_error(msg.str()); + } + if (listen(sock, size - 1) < 0) { + std::ostringstream msg; + msg << "[nccl] listen() failed: " << strerror(errno); + throw std::runtime_error(msg.str()); + } + + for (int peer = 1; peer < size; ++peer) { + int conn = accept(sock, nullptr, nullptr); + if (conn < 0) { + std::ostringstream msg; + msg << "[nccl] accept() failed: " << strerror(errno); + throw std::runtime_error(msg.str()); + } + sendAll(conn, &id, sizeof(id)); + close(conn); + } + close(sock); + + } else { + // Here just wanted to make show that rank 0 has enough time to bind + // so we will retry to connect until max attempts + + int sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock < 0) { + std::ostringstream msg; + msg << "[nccl] socket() failed: " << strerror(errno); + throw std::runtime_error(msg.str()); + } + + hostent* he = gethostbyname(host.c_str()); + if (!he) { + throw std::runtime_error("[nccl] lookup failed for host: " + host); + } + sockaddr_in serv = {}; + serv.sin_family = AF_INET; + memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length); + serv.sin_port = htons(port); + + const int max_retries = 30; + int attempt = 0; + bool connected = false; + + for (attempt = 0; attempt < max_retries; ++attempt) { + if (connect(sock, reinterpret_cast(&serv), sizeof(serv)) == + 0) { + connected = true; + std::cout << "[Rank " << rank << "] Connected successfully on attempt " + << attempt + 1 << std::endl; + break; + } + if (errno != ECONNREFUSED) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + } + + if (!connected) { + std::ostringstream msg; + msg << "[Rank " << rank << "] connect() failed after " << attempt + << " retries: " << strerror(errno); + close(sock); + throw std::runtime_error(msg.str()); + } + recvAll(sock, &id, sizeof(id)); + close(sock); + } +} + +inline ncclDataType_t datatype(const array& arr) { 
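+  // Map an MLX dtype to the matching ncclDataType_t; bool is rejected and any
+  // unknown dtype throws.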
+ switch (arr.dtype()) { + case bool_: + throw std::invalid_argument("[nccl] Boolean arrays not supported"); + case int8: + return ncclChar; + case uint8: + return ncclUint8; + case int32: + return ncclInt; + case uint32: + return ncclUint32; + case int64: + return ncclInt64; + case uint64: + return ncclUint64; + case float16: + return ncclHalf; + case float32: + return ncclFloat; + case float64: + return ncclDouble; + case bfloat16: + return ncclBfloat16; + default: + throw std::invalid_argument("[nccl] Unknown or unsupported dtype"); + } +} + +} // namespace detail + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; +// init communication in the constructor (?) +class NCCLGroup : public GroupImpl { + public: + NCCLGroup(int worldRank, int worldSize, const std::string initMethod) + : rank_(worldRank), + size_(worldSize), + comm_(nullptr), + initMethod_(initMethod) { + if (initialized_) + return; + int ndev; + CHECK_CUDA(cudaGetDeviceCount(&ndev)); + CHECK_CUDA(cudaSetDevice(rank_ % ndev)); + CHECK_CUDA(cudaStreamCreate(&stream_)); + + detail::bootstrapUniqueId(uniqueId_, rank_, size_, initMethod_); + CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_)); + initialized_ = true; + } + + ~NCCLGroup() { + ncclCommDestroy(comm_); + ncclGroupEnd(); + cudaStreamDestroy(stream_); + initialized_ = false; + } + + int rank() override { + return rank_; + } + + int size() override { + return size_; + } + + void all_sum(const array& input, array& output, Stream stream) override { + if (input.size() != output.size()) { + throw std::runtime_error( + "[nccl] Input and output arrays must have the same size."); + } + all_reduce_impl(input, output, stream, ncclSum); + } + + virtual std::shared_ptr split(int color, int key = -1) override { + throw std::runtime_error("[nccl] Group split not supported."); + } + + void all_gather(const array& input, array& output, Stream stream) override { + if (input.size() != output.size() / size_) { + throw std::runtime_error( + "[nccl] Input size must match output size divided by group size."); + } + } + + void send(const array& input, int dst, Stream stream) override { + if (input.size() == 0) { + return; // Nothing to send + } + } + + void recv(array& output, int src, Stream stream) override { + if (output.size() == 0) { + return; // Nothing to receive + } + } + + void all_max(const array& input, array& output, Stream stream) override { + if (input.size() != output.size()) { + throw std::runtime_error( + "[nccl] Input and output arrays must have the same size."); + } + all_reduce_impl(input, output, stream, ncclMax); + } + + void all_min(const array& input, array& output, Stream stream) override { + if (input.size() != output.size()) { + throw std::runtime_error( + "[nccl] Input and output arrays must have the same size."); + } + all_reduce_impl(input, output, stream, ncclMin); + } + + template + void all_reduce_impl( + const array& input, + array& output, + Stream stream, + ncclRedOp_t op) { + ncclDataType_t dt = detail::datatype(input); + + CHECK_NCCL(ncclAllReduce( + input.data(), + output.data(), + input.size(), + dt, + op, + comm_, + stream_)); + } + + int rank_, size_; + std::string initMethod_; + ncclUniqueId uniqueId_; + ncclComm_t comm_; + cudaStream_t stream_; + bool initialized_ = false; +}; + +bool is_available() { + return true; +} + +namespace detail { +static std::string get_env_var_or_throw(const char* env_var_name) { + const char* value = std::getenv(env_var_name); + if (value == nullptr) { + std::ostringstream msg; + msg << "[nccl] 
Required environment variable '" << env_var_name + << "' is not set. " + << "Please set it before initializing the distributed backend."; + throw std::runtime_error(msg.str()); + } + return std::string(value); +} +} // namespace detail + +std::shared_ptr init(bool strict /* = false */) { + std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP"); + std::string port = detail::get_env_var_or_throw("NCCL_PORT"); + std::string rank_str = detail::get_env_var_or_throw("MLX_RANK"); + std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE"); + + int rank = std::stoi(rank_str); + int n_nodes = std::stoi(n_nodes_str); + std::string init_method = "tcp://" + host + ":" + port; + + return std::make_shared(rank, n_nodes, init_method); +} +} // namespace mlx::core::distributed::nccl diff --git a/mlx/distributed/nccl/nccl.h b/mlx/distributed/nccl/nccl.h new file mode 100644 index 000000000..5370d2daf --- /dev/null +++ b/mlx/distributed/nccl/nccl.h @@ -0,0 +1,12 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/distributed/distributed.h" + +namespace mlx::core::distributed::nccl { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available(); +std::shared_ptr init(bool strict = false); + +} // namespace mlx::core::distributed::nccl diff --git a/mlx/distributed/nccl/no_nccl.cpp b/mlx/distributed/nccl/no_nccl.cpp new file mode 100644 index 000000000..1be256a11 --- /dev/null +++ b/mlx/distributed/nccl/no_nccl.cpp @@ -0,0 +1,20 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/distributed/nccl/nccl.h" + +namespace mlx::core::distributed::nccl { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available() { + return false; +} + +std::shared_ptr init(bool strict /* = false */) { + if (strict) { + throw std::runtime_error("Cannot initialize nccl distributed backend."); + } + return nullptr; +} + +} // namespace mlx::core::distributed::nccl diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp index 0a5114805..9c251a944 100644 --- a/mlx/distributed/ops.cpp +++ b/mlx/distributed/ops.cpp @@ -31,8 +31,7 @@ array all_sum( return array( x.shape(), x.dtype(), - std::make_shared( - to_stream(s, Device::cpu), group, AllReduce::Sum), + std::make_shared(to_stream(s), group, AllReduce::Sum), {x}); } From e9fbdd20fb8c04363910e0a41c3999a005067a49 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 16 Jun 2025 18:35:49 +0200 Subject: [PATCH 2/7] Helper function to parse types --- mlx/distributed/nccl/nccl.cpp | 62 +++++++++++++++++++++++++---------- 1 file changed, 44 insertions(+), 18 deletions(-) diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp index 7cb37a05a..8427ecf01 100644 --- a/mlx/distributed/nccl/nccl.cpp +++ b/mlx/distributed/nccl/nccl.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include "mlx/backend/cuda/device.h" #include "mlx/distributed/distributed.h" @@ -187,30 +188,46 @@ inline void bootstrapUniqueId( } } -inline ncclDataType_t datatype(const array& arr) { +template +struct type_identity { + using type = T; +}; + +template +void dispatch_dtype(const array& arr, F&& f) { switch (arr.dtype()) { case bool_: throw std::invalid_argument("[nccl] Boolean arrays not supported"); case int8: - return ncclChar; + f(type_identity{}, ncclChar); + break; case uint8: - return ncclUint8; + f(type_identity{}, ncclUint8); + break; case int32: - return ncclInt; + f(type_identity{}, ncclInt); + break; case uint32: - return ncclUint32; + f(type_identity{}, ncclUint32); + break; case int64: - 
return ncclInt64; + f(type_identity{}, ncclInt64); + break; case uint64: - return ncclUint64; + f(type_identity{}, ncclUint64); + break; case float16: - return ncclHalf; - case float32: - return ncclFloat; - case float64: - return ncclDouble; + f(type_identity{}, ncclHalf); + break; case bfloat16: - return ncclBfloat16; + f(type_identity{}, ncclBfloat16); + break; + case float32: + f(type_identity{}, ncclFloat); + break; + case float64: + f(type_identity{}, ncclDouble); + break; default: throw std::invalid_argument("[nccl] Unknown or unsupported dtype"); } @@ -259,7 +276,10 @@ class NCCLGroup : public GroupImpl { throw std::runtime_error( "[nccl] Input and output arrays must have the same size."); } - all_reduce_impl(input, output, stream, ncclSum); + detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { + using T = typename decltype(type_tag)::type; + all_reduce_impl(input, output, stream, dt, ncclSum); + }); } virtual std::shared_ptr split(int color, int key = -1) override { @@ -290,7 +310,10 @@ class NCCLGroup : public GroupImpl { throw std::runtime_error( "[nccl] Input and output arrays must have the same size."); } - all_reduce_impl(input, output, stream, ncclMax); + detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { + using T = typename decltype(type_tag)::type; + all_reduce_impl(input, output, stream, dt, ncclMax); + }); } void all_min(const array& input, array& output, Stream stream) override { @@ -298,7 +321,10 @@ class NCCLGroup : public GroupImpl { throw std::runtime_error( "[nccl] Input and output arrays must have the same size."); } - all_reduce_impl(input, output, stream, ncclMin); + detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { + using T = typename decltype(type_tag)::type; + all_reduce_impl(input, output, stream, dt, ncclMin); + }); } template @@ -306,9 +332,8 @@ class NCCLGroup : public GroupImpl { const array& input, array& output, Stream stream, + ncclDataType_t dt, ncclRedOp_t op) { - ncclDataType_t dt = detail::datatype(input); - CHECK_NCCL(ncclAllReduce( input.data(), output.data(), @@ -317,6 +342,7 @@ class NCCLGroup : public GroupImpl { op, comm_, stream_)); + cudaStreamSynchronize(stream_); } int rank_, size_; From 71a47bc10d6153f3834cca821cb67b0e66037fa2 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 16 Jun 2025 19:08:38 +0200 Subject: [PATCH 3/7] Deleted useless import --- mlx/distributed/distributed.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp index b299b1b9d..f791ee29e 100644 --- a/mlx/distributed/distributed.cpp +++ b/mlx/distributed/distributed.cpp @@ -2,7 +2,6 @@ #include -#include #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/mpi/mpi.h" From 70f2baf39f17efbfc0b698553bf59bc9b8d7146d Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 16 Jun 2025 19:11:28 +0200 Subject: [PATCH 4/7] Removed commented nogpu for all_reduce --- mlx/backend/cuda/primitives.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index d143f9fa5..137dd3b3c 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -122,7 +122,6 @@ NO_GPU_MULTI(CustomKernel) } // namespace fast namespace distributed { -// NO_GPU_MULTI(AllReduce) NO_GPU_MULTI(AllGather) NO_GPU_MULTI(Send) NO_GPU_MULTI(Recv) From e6ae3509999676632f4a9172c904a0f3df7fd7a4 Mon Sep 17 00:00:00 2001 From: Anastasiia 
Filippova Date: Tue, 17 Jun 2025 08:55:02 +0200 Subject: [PATCH 5/7] Deleted comments, renamed the function --- mlx/distributed/nccl/nccl.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp index 8427ecf01..f6fa28ad8 100644 --- a/mlx/distributed/nccl/nccl.cpp +++ b/mlx/distributed/nccl/nccl.cpp @@ -75,12 +75,12 @@ inline void recvAll(int sock, void* buf, size_t len) { } } -inline void bootstrapUniqueId( +inline void bootstrap_unique_id( ncclUniqueId& id, int rank, int size, const std::string& initMethod) { - // Parse the init method to extract the host and port + if (initMethod.rfind("tcp://", 0) != 0) throw; auto hostport = initMethod.substr(6); @@ -89,10 +89,8 @@ inline void bootstrapUniqueId( int port = std::stoi(hostport.substr(colon + 1)); if (rank == 0) { - // create a unique id on the rank 0 CHECK_NCCL(ncclGetUniqueId(&id)); - // create a socket to send the unique id to all other ranks int sock = socket(AF_INET, SOCK_STREAM, 0); if (sock < 0) { @@ -107,8 +105,6 @@ inline void bootstrapUniqueId( serv.sin_port = htons(port); int reuse = 1; - // Without this, if I crash or restart your rank-0 process quickly, - // the OS might refuse to let you bind to the same port, so reuse if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) { std::ostringstream msg; msg << "[nccl] setsockopt() failed: " << strerror(errno); @@ -236,7 +232,6 @@ void dispatch_dtype(const array& arr, F&& f) { } // namespace detail using GroupImpl = mlx::core::distributed::detail::GroupImpl; -// init communication in the constructor (?) class NCCLGroup : public GroupImpl { public: NCCLGroup(int worldRank, int worldSize, const std::string initMethod) @@ -334,6 +329,7 @@ class NCCLGroup : public GroupImpl { Stream stream, ncclDataType_t dt, ncclRedOp_t op) { + CHECK_NCCL(ncclAllReduce( input.data(), output.data(), From 043c37cccd88d5043a4a68634e56dd496e93640c Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Fri, 20 Jun 2025 16:07:41 +0200 Subject: [PATCH 6/7] Use last cuda stream instead of new one --- mlx/distributed/nccl/nccl.cpp | 71 +++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp index f6fa28ad8..02b1fc20c 100644 --- a/mlx/distributed/nccl/nccl.cpp +++ b/mlx/distributed/nccl/nccl.cpp @@ -80,7 +80,7 @@ inline void bootstrap_unique_id( int rank, int size, const std::string& initMethod) { - + // Parse the init method to extract the host and port if (initMethod.rfind("tcp://", 0) != 0) throw; auto hostport = initMethod.substr(6); @@ -89,8 +89,10 @@ inline void bootstrap_unique_id( int port = std::stoi(hostport.substr(colon + 1)); if (rank == 0) { + // create a unique id on the rank 0 CHECK_NCCL(ncclGetUniqueId(&id)); + // create a socket to send the unique id to all other ranks int sock = socket(AF_INET, SOCK_STREAM, 0); if (sock < 0) { @@ -105,6 +107,9 @@ inline void bootstrap_unique_id( serv.sin_port = htons(port); int reuse = 1; + // Without this, if rank-0 crashes or restarts process quickly, + // the OS might refuse to let binding to the same port, so reuse + if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) { std::ostringstream msg; msg << "[nccl] setsockopt() failed: " << strerror(errno); @@ -244,9 +249,7 @@ class NCCLGroup : public GroupImpl { int ndev; CHECK_CUDA(cudaGetDeviceCount(&ndev)); CHECK_CUDA(cudaSetDevice(rank_ % ndev)); - 
CHECK_CUDA(cudaStreamCreate(&stream_)); - - detail::bootstrapUniqueId(uniqueId_, rank_, size_, initMethod_); + detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_); CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_)); initialized_ = true; } @@ -254,7 +257,6 @@ class NCCLGroup : public GroupImpl { ~NCCLGroup() { ncclCommDestroy(comm_); ncclGroupEnd(); - cudaStreamDestroy(stream_); initialized_ = false; } @@ -267,13 +269,9 @@ class NCCLGroup : public GroupImpl { } void all_sum(const array& input, array& output, Stream stream) override { - if (input.size() != output.size()) { - throw std::runtime_error( - "[nccl] Input and output arrays must have the same size."); - } detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { using T = typename decltype(type_tag)::type; - all_reduce_impl(input, output, stream, dt, ncclSum); + detail::all_reduce_impl(input, output, stream, dt, ncclSum); }); } @@ -282,29 +280,45 @@ class NCCLGroup : public GroupImpl { } void all_gather(const array& input, array& output, Stream stream) override { - if (input.size() != output.size() / size_) { - throw std::runtime_error( - "[nccl] Input size must match output size divided by group size."); - } + detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { + using T = typename decltype(type_tag)::type; + CHECK_NCCL(ncclAllGather( + input.data(), + output.data(), + input.size(), + dt, + comm_, + cu::get_stream(stream).last_cuda_stream())); + }); } void send(const array& input, int dst, Stream stream) override { - if (input.size() == 0) { - return; // Nothing to send - } + detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { + using T = typename decltype(type_tag)::type; + CHECK_NCCL(ncclSend( + input.data(), + input.size(), + dt, + dst, + comm_, + cu::get_stream(stream).last_cuda_stream())); + }); } void recv(array& output, int src, Stream stream) override { - if (output.size() == 0) { - return; // Nothing to receive - } + detail::dispatch_dtype(output, [&](auto type_tag, ncclDataType_t dt) { + using T = typename decltype(type_tag)::type; + CHECK_NCCL(ncclRecv( + output.data(), + output.size(), + dt, + src, + comm_, + cu::get_stream(stream).last_cuda_stream())); + }); } void all_max(const array& input, array& output, Stream stream) override { - if (input.size() != output.size()) { - throw std::runtime_error( - "[nccl] Input and output arrays must have the same size."); - } detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { using T = typename decltype(type_tag)::type; all_reduce_impl(input, output, stream, dt, ncclMax); @@ -312,10 +326,6 @@ class NCCLGroup : public GroupImpl { } void all_min(const array& input, array& output, Stream stream) override { - if (input.size() != output.size()) { - throw std::runtime_error( - "[nccl] Input and output arrays must have the same size."); - } detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { using T = typename decltype(type_tag)::type; all_reduce_impl(input, output, stream, dt, ncclMin); @@ -329,7 +339,6 @@ class NCCLGroup : public GroupImpl { Stream stream, ncclDataType_t dt, ncclRedOp_t op) { - CHECK_NCCL(ncclAllReduce( input.data(), output.data(), @@ -337,15 +346,13 @@ class NCCLGroup : public GroupImpl { dt, op, comm_, - stream_)); - cudaStreamSynchronize(stream_); + cu::get_stream(stream).last_cuda_stream())); } int rank_, size_; std::string initMethod_; ncclUniqueId uniqueId_; ncclComm_t comm_; - cudaStream_t stream_; bool initialized_ = false; }; From 
bc6f00c00e7544ac4d2eddb6ea5c0c0fda771de1 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 5 Aug 2025 02:00:52 +0200 Subject: [PATCH 7/7] Changed nccl reduction to be a parrt of cuda grapph --- cmake/FindNCCL.cmake | 14 +-- mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/distributed.cu | 87 +++++++++++++++++++ .../cuda/iterators/strided_iterator.cuh | 60 +++++++++++++ mlx/backend/cuda/matmul.cpp | 1 + mlx/backend/cuda/primitives.cpp | 33 ++----- mlx/backend/cuda/reduce/segmented_reduce.cu | 84 ++++++++++++++++++ mlx/distributed/distributed.cpp | 3 +- mlx/distributed/nccl/nccl.cpp | 26 ++++-- 9 files changed, 264 insertions(+), 45 deletions(-) create mode 100644 mlx/backend/cuda/distributed.cu create mode 100644 mlx/backend/cuda/iterators/strided_iterator.cuh create mode 100644 mlx/backend/cuda/reduce/segmented_reduce.cu diff --git a/cmake/FindNCCL.cmake b/cmake/FindNCCL.cmake index 7f8791476..917640f0d 100644 --- a/cmake/FindNCCL.cmake +++ b/cmake/FindNCCL.cmake @@ -1,15 +1,5 @@ -# Find the nccl libraries -# -# The following variables are optionally searched for defaults NCCL_ROOT_DIR: -# Base directory where all NCCL components are found NCCL_INCLUDE_DIR: Directory -# where NCCL header is found NCCL_LIB_DIR: Directory where NCCL library is found -# -# The following are set after configuration is done: NCCL_FOUND -# NCCL_INCLUDE_DIRS NCCL_LIBRARIES -# -# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks install NCCL -# in the same location as the CUDA toolkit. See -# https://github.com/caffe2/caffe2/issues/1601 +# FindNCCL.cmake +# This module finds the NVIDIA NCCL library and its include directories. set(NCCL_ROOT_DIR $ENV{NCCL_ROOT_DIR} diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 8c1b999e9..5e0f970da 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -19,6 +19,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp diff --git a/mlx/backend/cuda/distributed.cu b/mlx/backend/cuda/distributed.cu new file mode 100644 index 000000000..df0fe4539 --- /dev/null +++ b/mlx/backend/cuda/distributed.cu @@ -0,0 +1,87 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/distributed/primitives.h" +#include "mlx/primitives.h" +#include "mlx/backend/cuda/kernel_utils.cuh" + + +#include + +namespace mlx::core { + namespace distributed { + void AllReduce::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + // Here I assume for now that in is donatable and contiguous. 
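+      // The collective below is issued through the encoder's capture context,
+      // so it is recorded as part of the CUDA graph rather than launched on a
+      // private stream.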
+ // TODO + assert(inputs.size() == 1); + assert(outputs.size() == 1); + + auto& input = inputs[0]; + auto& output = outputs[0]; + + auto& encoder = cu::get_command_encoder(stream()); + output.set_data(allocator::malloc(output.nbytes())); + + encoder.set_input_array(input); + encoder.set_output_array(output); + + auto capture = encoder.capture_context(); + auto& s = stream(); + + switch (reduce_type_) { + case Sum: + distributed::detail::all_sum(group(), input, output, s); + break; + case Max: + distributed::detail::all_max(group(), input, output, s); + break; + case Min: + distributed::detail::all_min(group(), input, output, s); + break; + default: + throw std::runtime_error( + "Only all reduce sum, max, and min are supported."); + } + } + + void Send::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + // Here FOR NOW I assume that it is always row_contigious + // because not sure how to copy correctly + // TODO + assert(inputs.size() == 1); + assert(outputs.size() == 1); + + distributed::detail::send(group(), inputs[0], dst_, stream()); + outputs[0].copy_shared_buffer(inputs[0]); + } + + void Recv::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 0); + assert(outputs.size() == 1); + outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); + distributed::detail::recv(group(), outputs[0], src_, stream()); + } + + void AllGather::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + // Here FOR NOW I assume that it is always row_contigious + // because not sure how to copy correctly + // TODO + assert(inputs.size() == 1); + assert(outputs.size() == 1); + + auto& input = inputs[0]; + auto& output = outputs[0]; + + output.copy_shared_buffer(input); + distributed::detail::all_gather(group(), input, output, stream()); + } + }// namespace distributed +} \ No newline at end of file diff --git a/mlx/backend/cuda/iterators/strided_iterator.cuh b/mlx/backend/cuda/iterators/strided_iterator.cuh new file mode 100644 index 000000000..3ef8d66bd --- /dev/null +++ b/mlx/backend/cuda/iterators/strided_iterator.cuh @@ -0,0 +1,60 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +namespace mlx::core::cu { + +// RandomAccessIterator for strided access to array entries. 
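+// Incrementing the iterator advances the underlying pointer by `stride`
+// elements, so Thrust/CUB algorithms can walk, for example, a single column of
+// a contiguous matrix.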
+template +class strided_iterator + : public thrust:: + iterator_adaptor, Iterator> { + public: + using super_t = + thrust::iterator_adaptor, Iterator>; + + using reference = typename super_t::reference; + using difference_type = typename super_t::difference_type; + + __host__ __device__ strided_iterator(Iterator it, Stride stride) + : super_t(it), stride_(stride) {} + + __host__ __device__ Stride stride() const { + return stride_; + } + + private: + friend class thrust::iterator_core_access; + + __host__ __device__ bool equal(const strided_iterator& other) const { + return this->base() == other.base(); + } + + __host__ __device__ void advance(difference_type n) { + this->base_reference() += n * stride_; + } + + __host__ __device__ void increment() { + this->base_reference() += stride_; + } + + __host__ __device__ void decrement() { + this->base_reference() -= stride_; + } + + __host__ __device__ difference_type + distance_to(const strided_iterator& other) const { + const difference_type dist = other.base() - this->base(); + _CCCL_ASSERT( + dist % stride() == 0, + "Underlying iterator difference must be divisible by the stride"); + return dist / stride(); + } + + Stride stride_; +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp index 283aaaf2e..93346d887 100644 --- a/mlx/backend/cuda/matmul.cpp +++ b/mlx/backend/cuda/matmul.cpp @@ -6,6 +6,7 @@ #include "mlx/backend/cuda/gemms/gemv.h" #include "mlx/backend/gpu/copy.h" #include "mlx/primitives.h" +#include "mlx/utils.h" #include #include diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp index c471fa8c2..08a457eaf 100644 --- a/mlx/backend/cuda/primitives.cpp +++ b/mlx/backend/cuda/primitives.cpp @@ -17,28 +17,6 @@ bool fast::ScaledDotProductAttention::use_fallback( return true; } -namespace distributed { -void AllReduce::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - // Here I assume for now that in is donatable and contiguous. - // TODO - - auto& input = inputs[0]; - auto& output = outputs[0]; - - output.copy_shared_buffer(input); - auto& s = stream(); - switch (reduce_type_) { - case Sum: - distributed::detail::all_sum(group(), input, output, s); - break; - default: - throw std::runtime_error("Only all reduce sum is supported for now"); - } -} -} // namespace distributed - #define NO_GPU_MULTI(func) \ void func::eval_gpu( \ const std::vector& inputs, std::vector& outputs) { \ @@ -79,10 +57,11 @@ NO_GPU(ScaledDotProductAttention) NO_GPU_MULTI(CustomKernel) } // namespace fast -namespace distributed { -NO_GPU_MULTI(AllGather) -NO_GPU_MULTI(Send) -NO_GPU_MULTI(Recv) -} // namespace distributed +// namespace distributed { +// NO_GPU_MULTI(AllReduce) +// NO_GPU_MULTI(AllGather) +// NO_GPU_MULTI(Send) +// NO_GPU_MULTI(Recv) +// } // namespace distributed } // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/segmented_reduce.cu b/mlx/backend/cuda/reduce/segmented_reduce.cu new file mode 100644 index 000000000..114d71809 --- /dev/null +++ b/mlx/backend/cuda/reduce/segmented_reduce.cu @@ -0,0 +1,84 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/cast_op.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include +#include +#include + +namespace mlx::core { + +template +void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) { + // Allocate temporary storage. 
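+  // CUB reductions are two-phase: the first call with a null workspace pointer
+  // only reports the required byte count; the second call runs the reduction.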
+ size_t size; + CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...)); + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + // Run op. + CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data(), size, args...)); +} + +template +void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) { + // Allocate temporary storage. + size_t size; + CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...)); + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + // Run op. + CHECK_CUDA_ERROR( + cub::DeviceSegmentedReduce::Reduce(temp.data(), size, args...)); +} + +struct MultiplyOp { + int factor; + __device__ int operator()(int i) { + return i * factor; + } +}; + +void segmented_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using InType = cuda_type_t; + using OutType = cu::ReduceResult::type; + auto in_iter = cu::make_cast_iterator( + thrust::device_pointer_cast(in.data())); + auto out_ptr = thrust::device_pointer_cast(out.data()); + auto init = cu::ReduceInit::value(); + + if (plan.type == ContiguousAllReduce) { + cub_all_reduce( + encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream); + } else if (plan.type == ContiguousReduce) { + auto offsets = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()}); + cub_segmented_reduce( + encoder, + in_iter, + out_ptr, + out.size(), + offsets, + offsets + 1, + OP(), + init, + stream); + } else { + throw std::runtime_error("Unsupported plan in segmented_reduce."); + } + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp index f791ee29e..a65329588 100644 --- a/mlx/distributed/distributed.cpp +++ b/mlx/distributed/distributed.cpp @@ -2,6 +2,7 @@ #include +#include #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/mpi/mpi.h" @@ -81,7 +82,7 @@ class EmptyGroup : public GroupImpl { } // namespace detail bool is_available() { - return mpi::is_available() || ring::is_available(); + return mpi::is_available() || ring::is_available() || nccl::is_available(); } int Group::rank() const { diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp index 02b1fc20c..c29851271 100644 --- a/mlx/distributed/nccl/nccl.cpp +++ b/mlx/distributed/nccl/nccl.cpp @@ -271,7 +271,12 @@ class NCCLGroup : public GroupImpl { void all_sum(const array& input, array& output, Stream stream) override { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { using T = typename decltype(type_tag)::type; - detail::all_reduce_impl(input, output, stream, dt, ncclSum); + all_reduce_impl( + input, + output, + stream, + dt, + ncclSum); }); } @@ -281,6 +286,8 @@ class NCCLGroup : public GroupImpl { void all_gather(const array& input, array& output, Stream stream) override { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { + auto& encoder = cu::get_command_encoder(stream); + using T = typename decltype(type_tag)::type; CHECK_NCCL(ncclAllGather( input.data(), @@ -288,12 +295,14 @@ class NCCLGroup : public GroupImpl { input.size(), dt, comm_, - 
cu::get_stream(stream).last_cuda_stream())); + encoder.stream())); }); } void send(const array& input, int dst, Stream stream) override { detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { + auto& encoder = cu::get_command_encoder(stream); + using T = typename decltype(type_tag)::type; CHECK_NCCL(ncclSend( input.data(), @@ -301,20 +310,22 @@ class NCCLGroup : public GroupImpl { dt, dst, comm_, - cu::get_stream(stream).last_cuda_stream())); + encoder.stream())); }); } void recv(array& output, int src, Stream stream) override { detail::dispatch_dtype(output, [&](auto type_tag, ncclDataType_t dt) { using T = typename decltype(type_tag)::type; + auto& encoder = cu::get_command_encoder(stream); + CHECK_NCCL(ncclRecv( output.data(), output.size(), dt, src, comm_, - cu::get_stream(stream).last_cuda_stream())); + encoder.stream())); }); } @@ -339,6 +350,9 @@ class NCCLGroup : public GroupImpl { Stream stream, ncclDataType_t dt, ncclRedOp_t op) { + + auto& encoder = cu::get_command_encoder(stream); + CHECK_NCCL(ncclAllReduce( input.data(), output.data(), @@ -346,7 +360,9 @@ class NCCLGroup : public GroupImpl { dt, op, comm_, - cu::get_stream(stream).last_cuda_stream())); + encoder.stream() + )); + } int rank_, size_;
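
Note: a minimal usage sketch of the backend added above, assuming a launcher exports the environment variables that nccl::init() reads (NCCL_HOST_IP, NCCL_PORT, MLX_RANK, MLX_WORLD_SIZE). The host/port values below and the setenv calls are placeholders for whatever the job script provides; ones/all_sum/eval are the standard MLX C++ ops, not part of this patch.

// Illustrative sketch only; values are placeholders normally set by the launcher.
#include <cstdlib>
#include "mlx/mlx.h"

int main() {
  // Rank 0 listens on NCCL_HOST_IP:NCCL_PORT and broadcasts the ncclUniqueId
  // to the other ranks over a plain TCP socket (see bootstrap_unique_id above).
  setenv("NCCL_HOST_IP", "10.0.0.1", 1);
  setenv("NCCL_PORT", "12345", 1);
  setenv("MLX_RANK", "0", 1);        // this process's rank
  setenv("MLX_WORLD_SIZE", "2", 1);  // total number of ranks

  auto group = mlx::core::distributed::init(/* strict */ true, "nccl");
  auto x = mlx::core::ones({1024});
  auto y = mlx::core::distributed::all_sum(x, group);
  mlx::core::eval(y);  // runs the NCCL all-reduce on the CUDA stream
  return 0;
}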