From 9392fc3f88b8a7c2d8b13f0f4bb76e63dacfbab6 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Thu, 21 Aug 2025 20:56:15 +0200 Subject: [PATCH] NCCL backend (#2476) --- .circleci/config.yml | 1 + cmake/FindNCCL.cmake | 54 ++++ docs/src/install.rst | 2 +- mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/distributed.cu | 51 ++++ mlx/backend/cuda/primitives.cpp | 1 - mlx/distributed/CMakeLists.txt | 1 + mlx/distributed/distributed.cpp | 13 +- mlx/distributed/distributed.h | 1 + mlx/distributed/distributed_impl.h | 8 + mlx/distributed/mpi/mpi.cpp | 4 + mlx/distributed/nccl/CMakeLists.txt | 8 + mlx/distributed/nccl/nccl.cpp | 359 ++++++++++++++++++++++++++ mlx/distributed/nccl/nccl.h | 12 + mlx/distributed/nccl/no_nccl.cpp | 20 ++ mlx/distributed/ops.cpp | 30 ++- mlx/distributed/ring/ring.cpp | 4 + python/mlx/distributed_run.py | 57 +++- python/mlx/nn/utils.py | 4 +- python/src/distributed.cpp | 2 +- python/tests/nccl_test_distributed.py | 284 ++++++++++++++++++++ 21 files changed, 897 insertions(+), 20 deletions(-) create mode 100644 cmake/FindNCCL.cmake create mode 100644 mlx/backend/cuda/distributed.cu create mode 100644 mlx/distributed/nccl/CMakeLists.txt create mode 100644 mlx/distributed/nccl/nccl.cpp create mode 100644 mlx/distributed/nccl/nccl.h create mode 100644 mlx/distributed/nccl/no_nccl.cpp create mode 100644 python/tests/nccl_test_distributed.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 7472c58f1..03987d39c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -222,6 +222,7 @@ jobs: sudo apt-get update sudo apt-get install libcudnn9-dev-cuda-12 sudo apt-get install libblas-dev liblapack-dev liblapacke-dev + sudo apt-get install libnccl2 libnccl-dev curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf - sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache rm -rf ccache-4.11.3-linux-x86_64 diff --git a/cmake/FindNCCL.cmake b/cmake/FindNCCL.cmake new file mode 100644 index 000000000..b31893241 --- /dev/null +++ b/cmake/FindNCCL.cmake @@ -0,0 +1,54 @@ +# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include +# directories. + +set(NCCL_ROOT_DIR + $ENV{NCCL_ROOT_DIR} + CACHE PATH "Folder contains NVIDIA NCCL") + +find_path( + NCCL_INCLUDE_DIRS + NAMES nccl.h + HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include + ${CUDA_TOOLKIT_ROOT_DIR}/include) + +if($ENV{USE_STATIC_NCCL}) + message( + STATUS "USE_STATIC_NCCL detected. 
Linking against static NCCL library") + set(NCCL_LIBNAME "libnccl_static.a") +else() + set(NCCL_LIBNAME "nccl") +endif() + +find_library( + NCCL_LIBRARIES + NAMES ${NCCL_LIBNAME} + HINTS ${NCCL_LIB_DIR} + ${NCCL_ROOT_DIR} + ${NCCL_ROOT_DIR}/lib + ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu + ${NCCL_ROOT_DIR}/lib64 + ${CUDA_TOOLKIT_ROOT_DIR}/lib + ${CUDA_TOOLKIT_ROOT_DIR}/lib64) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS + NCCL_LIBRARIES) + +if(NCCL_FOUND) + set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h") + message( + STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}") + file( + STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED + REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$" + LIMIT_COUNT 1) + if(NCCL_MAJOR_VERSION_DEFINED) + string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" "" + NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED}) + message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}") + endif() + message( + STATUS + "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") + mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) +endif() diff --git a/docs/src/install.rst b/docs/src/install.rst index 1e7a015ca..da7470908 100644 --- a/docs/src/install.rst +++ b/docs/src/install.rst @@ -271,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following: dpkg -i cuda-keyring_1.1-1_all.deb apt-get update -y apt-get -y install cuda-toolkit-12-9 - apt-get install libblas-dev liblapack-dev liblapacke-dev -y + apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y When building either the Python or C++ APIs make sure to pass the cmake flag diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 2e12c8c3e..8c3885384 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -22,6 +22,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp diff --git a/mlx/backend/cuda/distributed.cu b/mlx/backend/cuda/distributed.cu new file mode 100644 index 000000000..2cdf615f5 --- /dev/null +++ b/mlx/backend/cuda/distributed.cu @@ -0,0 +1,51 @@ +// Copyright © 2025 Apple Inc. 
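+// CUDA implementation of the distributed AllReduce primitive. The output
+// donates the input buffer when possible (otherwise it gets fresh storage),
+// and the reduction itself is delegated to the active distributed backend
+// (e.g. NCCL) via distributed::detail::all_sum / all_max / all_min.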
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/distributed/primitives.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core { +namespace distributed { +void AllReduce::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 1); + assert(outputs.size() == 1); + + auto& input = inputs[0]; + auto& output = outputs[0]; + + auto& encoder = cu::get_command_encoder(stream()); + + if (input.is_donatable()) { + output.copy_shared_buffer(input); + } else { + output.set_data(allocator::malloc(output.nbytes())); + } + + encoder.set_input_array(input); + encoder.set_output_array(output); + + auto capture = encoder.capture_context(); + auto& s = stream(); + + switch (reduce_type_) { + case Sum: + distributed::detail::all_sum(group(), input, output, s); + break; + case Max: + distributed::detail::all_max(group(), input, output, s); + break; + case Min: + distributed::detail::all_min(group(), input, output, s); + break; + default: + throw std::runtime_error( + "Only all reduce sum, max, and min are supported."); + } +} +} // namespace distributed +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp index aa20f0128..f9a594ab8 100644 --- a/mlx/backend/cuda/primitives.cpp +++ b/mlx/backend/cuda/primitives.cpp @@ -42,7 +42,6 @@ NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eigh) namespace distributed { -NO_GPU_MULTI(AllReduce) NO_GPU_MULTI(AllGather) NO_GPU_MULTI(Send) NO_GPU_MULTI(Recv) diff --git a/mlx/distributed/CMakeLists.txt b/mlx/distributed/CMakeLists.txt index 8e16bd40d..b7762f6a7 100644 --- a/mlx/distributed/CMakeLists.txt +++ b/mlx/distributed/CMakeLists.txt @@ -6,3 +6,4 @@ target_sources( add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl) diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp index cc01e6090..44205e87e 100644 --- a/mlx/distributed/distributed.cpp +++ b/mlx/distributed/distributed.cpp @@ -5,12 +5,17 @@ #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/mpi/mpi.h" +#include "mlx/distributed/nccl/nccl.h" #include "mlx/distributed/ring/ring.h" namespace mlx::core::distributed { namespace detail { +Stream communication_stream(Group group, StreamOrDevice s /* = {} */) { + return group.raw_group()->communication_stream(s); +} + void all_sum(Group group, const array& input, array& output, Stream stream) { group.raw_group()->all_sum(input, output, stream); } @@ -37,6 +42,10 @@ void recv(Group group, array& out, int src, Stream stream) { class EmptyGroup : public GroupImpl { public: + Stream communication_stream(StreamOrDevice s) override { + return to_stream(s); + } + int rank() override { return 0; } @@ -80,7 +89,7 @@ class EmptyGroup : public GroupImpl { } // namespace detail bool is_available() { - return mpi::is_available() || ring::is_available(); + return mpi::is_available() || ring::is_available() || nccl::is_available(); } int Group::rank() const { @@ -111,6 +120,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) { group = mpi::init(strict); } else if (bk == "ring") { group = ring::init(strict); + } else if (bk == "nccl") { + group = nccl::init(strict); } else if (bk == "any") { group = ring::init(false); bk_ = "ring"; diff --git a/mlx/distributed/distributed.h b/mlx/distributed/distributed.h index 
1f1713866..fa5c42a1f 100644 --- a/mlx/distributed/distributed.h +++ b/mlx/distributed/distributed.h @@ -5,6 +5,7 @@ #include #include "mlx/array.h" +#include "mlx/utils.h" namespace mlx::core::distributed { diff --git a/mlx/distributed/distributed_impl.h b/mlx/distributed/distributed_impl.h index 8b0327131..c90b0ba47 100644 --- a/mlx/distributed/distributed_impl.h +++ b/mlx/distributed/distributed_impl.h @@ -13,10 +13,15 @@ class GroupImpl { public: virtual ~GroupImpl() {} + // Choose the stream this communication group can operate on + virtual Stream communication_stream(StreamOrDevice s = {}) = 0; + + // Group operations virtual int rank() = 0; virtual int size() = 0; virtual std::shared_ptr split(int color, int key = -1) = 0; + // Actual communication operations virtual void all_sum(const array& input, array& output, Stream stream) = 0; virtual void all_gather(const array& input, array& output, Stream stream) = 0; virtual void send(const array& input, int dst, Stream stream) = 0; @@ -25,6 +30,9 @@ class GroupImpl { virtual void all_min(const array& input, array& output, Stream stream) = 0; }; +/* Define the MLX stream that the communication should happen in. */ +Stream communication_stream(Group group, StreamOrDevice s = {}); + /* Perform an all reduce sum operation */ void all_sum(Group group, const array& input, array& output, Stream stream); diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index 6a440c319..494fb02dc 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -349,6 +349,10 @@ class MPIGroup : public GroupImpl { } } + Stream communication_stream(StreamOrDevice s) override { + return to_stream(s, Device::cpu); + } + int rank() override { if (rank_ < 0) { mpi().rank(comm_, &rank_); diff --git a/mlx/distributed/nccl/CMakeLists.txt b/mlx/distributed/nccl/CMakeLists.txt new file mode 100644 index 000000000..2f764c6ac --- /dev/null +++ b/mlx/distributed/nccl/CMakeLists.txt @@ -0,0 +1,8 @@ +if(MLX_BUILD_CUDA) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp) + find_package(NCCL REQUIRED) + target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES}) + target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS}) +else() + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp) +endif() diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp new file mode 100644 index 000000000..43af9c724 --- /dev/null +++ b/mlx/distributed/nccl/nccl.cpp @@ -0,0 +1,359 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mlx/backend/cuda/device.h" +#include "mlx/distributed/distributed.h" +#include "mlx/distributed/distributed_impl.h" +#include "mlx/dtype_utils.h" +#include "mlx/utils.h" + +namespace mlx::core::distributed::nccl { + +#define CHECK_CUDA(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + fprintf( \ + stderr, \ + "CUDA error %s:%d '%s'\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(e)); \ + exit(1); \ + } \ + } while (0) + +#define CHECK_NCCL(cmd) \ + do { \ + ncclResult_t r = cmd; \ + if (r != ncclSuccess) { \ + fprintf( \ + stderr, \ + "NCCL error %s:%d '%s'\n", \ + __FILE__, \ + __LINE__, \ + ncclGetErrorString(r)); \ + exit(1); \ + } \ + } while (0) + +#define MLX_NCCL_TYPE_LIST(X) \ + X(int8_t, ncclChar) \ + X(uint8_t, ncclUint8) \ + X(int32_t, ncclInt) \ + X(uint32_t, ncclUint32) \ + X(int64_t, ncclInt64) \ + X(uint64_t, ncclUint64) \ + X(float16_t, ncclHalf) \ + 
X(bfloat16_t, ncclBfloat16) \ + X(float, ncclFloat) \ + X(double, ncclDouble) + +template +struct nccl_map { + static constexpr bool ok = false; // default: unsupported +}; + +#define MLX_DEF_NCCL_MAP(T, E) \ + template <> \ + struct nccl_map { \ + static constexpr bool ok = true; \ + static constexpr ncclDataType_t value = E; \ + }; + +MLX_NCCL_TYPE_LIST(MLX_DEF_NCCL_MAP) +#undef MLX_DEF_NCCL_MAP + +namespace detail { + +template +void dispatch_dtype(const array& arr, F&& f) { + dispatch_all_types(arr.dtype(), [&](auto type_tag) { + using T = MLX_GET_TYPE(type_tag); + if constexpr (nccl_map::ok) { + f(type_tag, nccl_map::value); + } else { + throw std::invalid_argument("[nccl] Unknown or unsupported dtype"); + } + }); +} + +inline void sendAll(int sock, const void* buf, size_t len) { + const char* ptr = reinterpret_cast(buf); + while (len > 0) { + ssize_t sent = send(sock, ptr, len, 0); + if (sent <= 0) { + perror("send"); + exit(1); + } + ptr += sent; + len -= sent; + } +} + +inline void recvAll(int sock, void* buf, size_t len) { + char* ptr = reinterpret_cast(buf); + while (len > 0) { + ssize_t rec = recv(sock, ptr, len, 0); + if (rec <= 0) { + perror("recv"); + exit(1); + } + ptr += rec; + len -= rec; + } +} + +inline void bootstrap_unique_id( + ncclUniqueId& id, + int rank, + int size, + const std::string& initMethod) { + // Parse the init method to extract the host and port + if (initMethod.rfind("tcp://", 0) != 0) + throw; + auto hostport = initMethod.substr(6); + auto colon = hostport.find(':'); + std::string host = hostport.substr(0, colon); + int port = std::stoi(hostport.substr(colon + 1)); + + if (rank == 0) { + // create a unique id on the rank 0 + CHECK_NCCL(ncclGetUniqueId(&id)); + + // create a socket to send the unique id to all other ranks + int sock = socket(AF_INET, SOCK_STREAM, 0); + + if (sock < 0) { + std::ostringstream msg; + msg << "[nccl] Couldn't create socket (error: " << errno << ")"; + throw std::runtime_error(msg.str()); + } + + sockaddr_in serv = {}; + serv.sin_family = AF_INET; + serv.sin_addr.s_addr = htonl(INADDR_ANY); + serv.sin_port = htons(port); + + int reuse = 1; + // Without this, if rank-0 crashes or restarts process quickly, + // the OS might refuse to let binding to the same port, so reuse + + if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) { + std::ostringstream msg; + msg << "[nccl] setsockopt() failed: " << strerror(errno); + throw std::runtime_error(msg.str()); + } + + if (bind(sock, reinterpret_cast(&serv), sizeof(serv)) < 0) { + std::ostringstream msg; + msg << "[nccl] bind() failed: " << strerror(errno); + throw std::runtime_error(msg.str()); + } + if (listen(sock, size - 1) < 0) { + std::ostringstream msg; + msg << "[nccl] listen() failed: " << strerror(errno); + throw std::runtime_error(msg.str()); + } + + for (int peer = 1; peer < size; ++peer) { + int conn = accept(sock, nullptr, nullptr); + if (conn < 0) { + std::ostringstream msg; + msg << "[nccl] accept() failed: " << strerror(errno); + throw std::runtime_error(msg.str()); + } + sendAll(conn, &id, sizeof(id)); + close(conn); + } + close(sock); + + } else { + // Here just wanted to make show that rank 0 has enough time to bind + // so we will retry to connect until max attempts + + int sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock < 0) { + std::ostringstream msg; + msg << "[nccl] socket() failed: " << strerror(errno); + throw std::runtime_error(msg.str()); + } + + hostent* he = gethostbyname(host.c_str()); + if (!he) { + throw 
std::runtime_error("[nccl] lookup failed for host: " + host); + } + sockaddr_in serv = {}; + serv.sin_family = AF_INET; + memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length); + serv.sin_port = htons(port); + + const int max_retries = 30; + int attempt = 0; + bool connected = false; + + for (attempt = 0; attempt < max_retries; ++attempt) { + if (connect(sock, reinterpret_cast(&serv), sizeof(serv)) == + 0) { + connected = true; + std::cout << "[Rank " << rank << "] Connected successfully on attempt " + << attempt + 1 << std::endl; + break; + } + if (errno != ECONNREFUSED) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + } + + if (!connected) { + std::ostringstream msg; + msg << "[Rank " << rank << "] connect() failed after " << attempt + << " retries: " << strerror(errno); + close(sock); + throw std::runtime_error(msg.str()); + } + recvAll(sock, &id, sizeof(id)); + close(sock); + } +} + +} // namespace detail + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; +class NCCLGroup : public GroupImpl { + public: + NCCLGroup(int worldRank, int worldSize, const std::string initMethod) + : rank_(worldRank), + size_(worldSize), + comm_(nullptr), + initMethod_(initMethod) { + if (initialized_) + return; + int ndev; + CHECK_CUDA(cudaGetDeviceCount(&ndev)); + CHECK_CUDA(cudaSetDevice(rank_ % ndev)); + detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_); + CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_)); + initialized_ = true; + } + + ~NCCLGroup() { + ncclCommDestroy(comm_); + ncclGroupEnd(); + initialized_ = false; + } + + Stream communication_stream(StreamOrDevice s) override { + return to_stream(s, Device::gpu); + } + + int rank() override { + return rank_; + } + + int size() override { + return size_; + } + + void all_sum(const array& input, array& output, Stream stream) override { + detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { + using T = typename decltype(type_tag)::type; + all_reduce_impl(input, output, stream, dt, ncclSum); + }); + } + + virtual std::shared_ptr split(int color, int key = -1) override { + throw std::runtime_error("[nccl] Group split not supported."); + } + + void all_gather(const array& input, array& output, Stream stream) override { + throw std::runtime_error( + "[nccl] All gather not supported in NCCL backend."); + } + + void send(const array& input, int dst, Stream stream) override { + throw std::runtime_error("[nccl] Send not supported in NCCL backend."); + } + + void recv(array& output, int src, Stream stream) override { + throw std::runtime_error("[nccl] Recv not supported in NCCL backend."); + } + + void all_max(const array& input, array& output, Stream stream) override { + throw std::runtime_error("[nccl] All max not supported in NCCL backend."); + } + + void all_min(const array& input, array& output, Stream stream) override { + throw std::runtime_error("[nccl] All min not supported in NCCL backend."); + } + + template + void all_reduce_impl( + const array& input, + array& output, + Stream stream, + ncclDataType_t dt, + ncclRedOp_t op) { + auto& encoder = cu::get_command_encoder(stream); + + CHECK_NCCL(ncclAllReduce( + input.data(), + output.data(), + input.size(), + dt, + op, + comm_, + encoder.stream())); + } + + int rank_, size_; + std::string initMethod_; + ncclUniqueId uniqueId_; + ncclComm_t comm_; + bool initialized_ = false; +}; + +bool is_available() { + return true; +} + +namespace detail { +static std::string get_env_var_or_throw(const char* env_var_name) { + 
const char* value = std::getenv(env_var_name); + if (value == nullptr) { + std::ostringstream msg; + msg << "[nccl] Required environment variable '" << env_var_name + << "' is not set. " + << "Please set it before initializing the distributed backend."; + throw std::runtime_error(msg.str()); + } + return std::string(value); +} +} // namespace detail + +std::shared_ptr init(bool strict /* = false */) { + std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP"); + std::string port = detail::get_env_var_or_throw("NCCL_PORT"); + std::string rank_str = detail::get_env_var_or_throw("MLX_RANK"); + std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE"); + + int rank = std::stoi(rank_str); + int n_nodes = std::stoi(n_nodes_str); + std::string init_method = "tcp://" + host + ":" + port; + + return std::make_shared(rank, n_nodes, init_method); +} +} // namespace mlx::core::distributed::nccl diff --git a/mlx/distributed/nccl/nccl.h b/mlx/distributed/nccl/nccl.h new file mode 100644 index 000000000..5370d2daf --- /dev/null +++ b/mlx/distributed/nccl/nccl.h @@ -0,0 +1,12 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/distributed/distributed.h" + +namespace mlx::core::distributed::nccl { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available(); +std::shared_ptr init(bool strict = false); + +} // namespace mlx::core::distributed::nccl diff --git a/mlx/distributed/nccl/no_nccl.cpp b/mlx/distributed/nccl/no_nccl.cpp new file mode 100644 index 000000000..1be256a11 --- /dev/null +++ b/mlx/distributed/nccl/no_nccl.cpp @@ -0,0 +1,20 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/distributed/nccl/nccl.h" + +namespace mlx::core::distributed::nccl { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available() { + return false; +} + +std::shared_ptr init(bool strict /* = false */) { + if (strict) { + throw std::runtime_error("Cannot initialize nccl distributed backend."); + } + return nullptr; +} + +} // namespace mlx::core::distributed::nccl diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp index 0a5114805..157bc2612 100644 --- a/mlx/distributed/ops.cpp +++ b/mlx/distributed/ops.cpp @@ -2,6 +2,9 @@ #include +#include "mlx/backend/cuda/cuda.h" +#include "mlx/backend/metal/metal.h" +#include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/ops.h" #include "mlx/distributed/primitives.h" @@ -28,11 +31,12 @@ array all_sum( if (group.size() == 1) { return x; } + auto stream = detail::communication_stream(group, s); + return array( x.shape(), x.dtype(), - std::make_shared( - to_stream(s, Device::cpu), group, AllReduce::Sum), + std::make_shared(stream, group, AllReduce::Sum), {x}); } @@ -45,11 +49,12 @@ array all_max( if (group.size() == 1) { return x; } + auto stream = detail::communication_stream(group, s); + return array( x.shape(), x.dtype(), - std::make_shared( - to_stream(s, Device::cpu), group, AllReduce::Max), + std::make_shared(stream, group, AllReduce::Max), {x}); } @@ -62,11 +67,12 @@ array all_min( if (group.size() == 1) { return x; } + auto stream = detail::communication_stream(group, s); + return array( x.shape(), x.dtype(), - std::make_shared( - to_stream(s, Device::cpu), group, AllReduce::Min), + std::make_shared(stream, group, AllReduce::Min), {x}); } @@ -79,6 +85,7 @@ array all_gather( if (group.size() == 1) { return x; } + auto stream = detail::communication_stream(group, s); auto result_shape = x.shape(); if (result_shape.size() == 0) { @@ -89,7 +96,7 @@ array all_gather( return array( 
std::move(result_shape), x.dtype(), - std::make_shared(to_stream(s, Device::cpu), group), + std::make_shared(stream, group), {x}); } @@ -103,6 +110,7 @@ array send( if (group.size() == 1) { throw std::invalid_argument("Cannot send to a singleton group"); } + auto stream = detail::communication_stream(group, s); if (dst < 0 || dst >= group.size()) { std::ostringstream msg; @@ -112,10 +120,7 @@ array send( } return array( - x.shape(), - x.dtype(), - std::make_shared(to_stream(s, Device::cpu), group, dst), - {x}); + x.shape(), x.dtype(), std::make_shared(stream, group, dst), {x}); } array recv( @@ -129,6 +134,7 @@ array recv( if (group.size() == 1) { throw std::invalid_argument("Cannot recv from a singleton group"); } + auto stream = detail::communication_stream(group, s); if (src < 0 || src >= group.size()) { std::ostringstream msg; @@ -139,7 +145,7 @@ array recv( return array( std::move(shape), std::move(dtype), - std::make_shared(to_stream(s, Device::cpu), group, src), + std::make_shared(stream, group, src), std::vector{}); } diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index b31274e23..7c3dcf095 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -619,6 +619,10 @@ class RingGroup : public GroupImpl { } } + Stream communication_stream(StreamOrDevice s) override { + return to_stream(s, Device::cpu); + } + int rank() override { return rank_; } diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 404ecc349..afd8b5130 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -415,6 +415,48 @@ def launch_mpi(parser, hosts, args, command): pass +def launch_nccl(parser, hosts, args, command): + master_host = hosts[0].ips[0] + + if master_host != "127.0.0.1": + raise ValueError("The NCCL backend only supports localhost for now. 
") + master_port = args.nccl_port + world_size = len(hosts) + + base_env = os.environ.copy() + base_env.update( + { + "NCCL_DEBUG": "INFO", + "NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication + "NCCL_HOST_IP": master_host, + "NCCL_PORT": str(master_port), + "MLX_WORLD_SIZE": str(world_size), + } + ) + procs = [] + try: + for rank in range(world_size): + env = base_env.copy() + env["MLX_RANK"] = str(rank) + env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node) + p = Popen(command, env=env) + procs.append(p) + + for p in procs: + ret = p.wait() + if ret != 0: + raise RuntimeError(f"Rank process exited with {ret}") + + except (RuntimeError, KeyboardInterrupt) as err: + for p in procs: + if p.poll() is None: + try: + p.kill() + except Exception: + pass + raise + + def check_ssh_connections(hosts): results = [False] * len(hosts) @@ -665,7 +707,7 @@ def distributed_config(): ) parser.add_argument( "--backend", - choices=["ring", "mpi"], + choices=["ring", "mpi", "nccl"], default="ring", help="Which distributed backend to configure", ) @@ -737,7 +779,7 @@ def main(): parser.add_argument("--hostfile", help="The file containing the hosts") parser.add_argument( "--backend", - choices=["ring", "mpi"], + choices=["ring", "mpi", "nccl"], default="ring", help="Which distributed backend to launch", ) @@ -769,6 +811,13 @@ def main(): parser.add_argument( "--cwd", help="Set the working directory on each node to the provided one" ) + parser.add_argument( + "--nccl-port", + type=int, + default=12345, + help="The port to use for the NCCL communication (only for nccl backend)", + ) + args, rest = parser.parse_known_args() if rest[0] == "--": rest.pop(0) @@ -799,8 +848,10 @@ def main(): # Launch if args.backend == "ring": launch_ring(parser, hosts, args, rest) - elif args.backend == "mpi": + if args.backend == "mpi": launch_mpi(parser, hosts, args, rest) + if args.backend == "nccl": + launch_nccl(parser, hosts, args, rest) if __name__ == "__main__": diff --git a/python/mlx/nn/utils.py b/python/mlx/nn/utils.py index 6cc799a7c..8c786454f 100644 --- a/python/mlx/nn/utils.py +++ b/python/mlx/nn/utils.py @@ -76,6 +76,7 @@ def average_gradients( group: Optional[mx.distributed.Group] = None, all_reduce_size: int = 32 * 1024**2, communication_type: Optional[mx.Dtype] = None, + stream: mx.Stream = mx.cpu, ): """Average the gradients across the distributed processes in the passed group. @@ -94,6 +95,7 @@ def average_gradients( communication_type (Optional[mlx.core.Dtype]): If provided cast to this type before performing the communication. Typically cast to a smaller float to reduce the communication size. Default: ``None``. + stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``. """ group = group or mx.distributed.init() N = group.size() @@ -104,7 +106,7 @@ def average_gradients( def _average(x): dt = x.dtype x = x.astype(communication_type) if communication_type is not None else x - return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N + return mx.distributed.all_sum(x, stream=stream).astype(dt) / N if all_reduce_size <= 0: return tree_map(_average, gradients) diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp index b52fa86c0..e2e191dbb 100644 --- a/python/src/distributed.cpp +++ b/python/src/distributed.cpp @@ -79,7 +79,7 @@ void init_distributed(nb::module_& parent_module) { in case ``mx.distributed.is_available()`` returns False otherwise it throws a runtime error. 
Default: ``False`` backend (str, optional): Which distributed backend to initialize. - Possible values ``mpi``, ``ring``, ``any``. If set to ``any`` all + Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all available backends are tried and the first one that succeeds becomes the global group which will be returned in subsequent calls. Default: ``any`` diff --git a/python/tests/nccl_test_distributed.py b/python/tests/nccl_test_distributed.py new file mode 100644 index 000000000..c55fb5c1f --- /dev/null +++ b/python/tests/nccl_test_distributed.py @@ -0,0 +1,284 @@ +# Copyright © 2024 Apple Inc. +import mlx.core as mx +import mlx.nn as nn +import mlx_tests +from mlx.nn.layers.distributed import shard_inplace, shard_linear +from mlx.nn.utils import average_gradients + + +class TestNCCLDistributed(mlx_tests.MLXTestCase): + @classmethod + def setUpClass(cls): + world = mx.distributed.init(strict=True, backend="nccl") + rank = world.rank() + mx.set_default_device(mx.Device(mx.gpu, rank % 8)) + + def test_all_reduce(self): + world = mx.distributed.init() + dtypes = [ + (mx.int8, 0), + (mx.uint8, 0), + (mx.int32, 0), + (mx.uint32, 0), + (mx.float32, 1e-6), + (mx.float16, 5e-3), + (mx.bfloat16, 1e-1), + ] + sizes = [ + (7,), + (10,), + (1024,), + (1024, 1024), + ] + key = mx.random.key(0) + + for dt, rtol in dtypes: + for sh in sizes: + x = ( + mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10 + ).astype(dt) + + # All sum + y = mx.distributed.all_sum(x[world.rank()]) + z = x.sum(0) + maxrelerror = (y - z).abs() + if rtol > 0: + maxrelerror /= z.abs() + maxrelerror = maxrelerror.max() + self.assertLessEqual(maxrelerror, rtol) + + def test_average_gradients(self): + original_all_sum = mx.distributed.all_sum + n_calls = 0 + xtype = None + + def new_all_sum(x, **kwargs): + nonlocal n_calls + nonlocal xtype + + n_calls += 1 + if xtype is not None: + self.assertEqual(xtype, x.dtype) + + return original_all_sum(x, **kwargs) + + mx.distributed.all_sum = new_all_sum + try: + grads = [mx.ones(10) for i in range(10)] + new_grads = average_gradients(grads, stream=mx.gpu) + mx.eval(new_grads) + self.assertEqual(len(new_grads), 10) + self.assertTrue(all(mx.all(g == 1) for g in new_grads)) + self.assertEqual(n_calls, 1) + + n_calls = 0 + new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu) + mx.eval(new_grads) + self.assertEqual(len(new_grads), 10) + self.assertTrue(all(mx.all(g == 1) for g in new_grads)) + self.assertEqual(n_calls, 2) + + n_calls = 0 + new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu) + mx.eval(new_grads) + self.assertEqual(len(new_grads), 10) + self.assertTrue(all(mx.all(g == 1) for g in new_grads)) + self.assertEqual(n_calls, 10) + + n_calls = 0 + xtype = mx.float16 + new_grads = average_gradients( + grads, + all_reduce_size=2 * 50, + communication_type=mx.float16, + stream=mx.gpu, + ) + mx.eval(new_grads) + self.assertEqual(len(new_grads), 10) + self.assertTrue(all(g.dtype == mx.float32 for g in new_grads)) + self.assertTrue(all(mx.all(g == 1) for g in new_grads)) + self.assertEqual(n_calls, 2) + + finally: + mx.distributed.all_sum = original_all_sum + + def test_donation(self): + x = mx.random.normal((1024,)) + mx.eval(x) + mx.synchronize() + + mx.reset_peak_memory() + scale = mx.array(2.0) + y = mx.distributed.all_sum(x) + mx.eval(y) + mx.synchronize() + all_sum_only = mx.get_peak_memory() + y = mx.distributed.all_sum(x) * scale + mx.eval(y) + mx.synchronize() + all_sum_with_binary = mx.get_peak_memory() + + 
self.assertEqual(all_sum_only, all_sum_with_binary) + + def test_shard_linear(self): + # Seed the prng to have the same inputs and weights generated everywhere + mx.random.seed(0xF0F0F0F0) + + # Prepare inputs + world = mx.distributed.init() + part = ( + slice(None), + slice( + world.rank() * 1024 // world.size(), + (world.rank() + 1) * 1024 // world.size(), + ), + ) + x = mx.random.normal((4, 1024)) + + # Create and shard some linear layers + lin = nn.Linear(1024, 1024, bias=True) + slin1 = shard_linear(lin, "all-to-sharded") + slin2 = shard_linear(lin, "sharded-to-all") + y = lin(x) + y1 = slin1(x) + y2 = slin2(x[part]) + self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4)) + self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4)) + + # Check the backward works as expected + def dummy_loss(model, x, y): + return (model(x) * y).sum() + + mod = nn.Sequential( + nn.Linear(128, 128), + nn.Linear(128, 128), + nn.Linear(128, 128), + nn.Linear(128, 128), + ) + smod = nn.Sequential( + shard_linear(mod.layers[0], "all-to-sharded"), + shard_linear(mod.layers[1], "sharded-to-all"), + shard_linear(mod.layers[2], "all-to-sharded"), + shard_linear(mod.layers[3], "sharded-to-all"), + ) + + grad1 = nn.value_and_grad(mod, dummy_loss) + grad2 = nn.value_and_grad(smod, dummy_loss) + + x = mx.random.normal((4, 128)) + y = mx.random.normal((4, 128)) + + l1, g1 = grad1(mod, x, y) + l2, g2 = grad2(smod, x, y) + mx.eval(l1, g1, l2, g2) + + part = slice( + world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size() + ) + + self.assertTrue(mx.allclose(l1, l2)) + self.assertTrue( + mx.allclose( + g1["layers"][0]["weight"][part], + g2["layers"][0]["weight"], + atol=1e-6, + rtol=1e-4, + ) + ) + self.assertTrue( + mx.allclose( + g1["layers"][2]["weight"][part], + g2["layers"][2]["weight"], + atol=1e-6, + rtol=1e-4, + ) + ) + self.assertTrue( + mx.allclose( + g1["layers"][1]["weight"][:, part], + g2["layers"][1]["weight"], + atol=1e-6, + rtol=1e-4, + ) + ) + self.assertTrue( + mx.allclose( + g1["layers"][3]["weight"][:, part], + g2["layers"][3]["weight"], + atol=1e-6, + rtol=1e-4, + ) + ) + self.assertTrue( + mx.allclose( + g1["layers"][0]["bias"][part], + g2["layers"][0]["bias"], + atol=1e-6, + rtol=1e-4, + ) + ) + self.assertTrue( + mx.allclose( + g1["layers"][2]["bias"][part], + g2["layers"][2]["bias"], + atol=1e-6, + rtol=1e-4, + ) + ) + self.assertTrue( + mx.allclose( + g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4 + ) + ) + self.assertTrue( + mx.allclose( + g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4 + ) + ) + + def test_shard_predicate(self): + mx.random.seed(0xF0F0F0F0) + + class MyConv(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.aggregate = kwargs.pop("aggregate", False) + self.conv = nn.Conv2d(*args, **kwargs) + + def __call__(self, x): + x = self.conv(x) + if self.aggregate: + x = mx.distributed.all_sum(x) + return x + + def sharding(path, weight): + parts = path.split(".") + even = int(parts[1]) % 2 == 0 + if even: + return 0 + else: + return -1 if parts[-1] != "bias" else None + + mod = nn.Sequential( + MyConv(3, 128, kernel_size=3), + MyConv(128, 128, kernel_size=3), + MyConv(128, 128, kernel_size=3), + MyConv(128, 3, kernel_size=3), + ) + smod = nn.Sequential( + MyConv(3, 128, kernel_size=3), + MyConv(128, 128, kernel_size=3, aggregate=True), + MyConv(128, 128, kernel_size=3), + MyConv(128, 3, kernel_size=3, aggregate=True), + ) + smod.update(mod.parameters()) + shard_inplace(smod, 
sharding) + + x = mx.random.normal((4, 16, 16, 3)) + y1 = mod(x) + y2 = smod(x) + self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4)) + + +if __name__ == "__main__": + mlx_tests.MLXTestRunner()
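
Usage sketch (not part of the patch): a minimal per-rank script as it would be launched by launch_nccl() above, which exports NCCL_HOST_IP, NCCL_PORT, MLX_RANK, MLX_WORLD_SIZE and CUDA_VISIBLE_DEVICES before spawning each rank. The file name and the exact launch invocation are assumptions; see distributed_run.py above for the flags added here (--backend nccl, --nccl-port).

    # train_step.py (hypothetical per-rank script)
    import mlx.core as mx
    from mlx.nn.utils import average_gradients

    # The nccl backend reads NCCL_HOST_IP, NCCL_PORT, MLX_RANK and MLX_WORLD_SIZE
    # set by the launcher, bootstraps the unique id over TCP, and initializes NCCL.
    world = mx.distributed.init(strict=True, backend="nccl")
    mx.set_default_device(mx.Device(mx.gpu, world.rank() % 8))

    # all_sum now runs on the GPU stream chosen by communication_stream().
    x = mx.ones((1024,))
    mx.eval(mx.distributed.all_sum(x))

    # Gradient averaging can also target the GPU stream via the new argument.
    grads = [mx.ones((10,)) for _ in range(4)]
    mx.eval(average_gradients(grads, stream=mx.gpu))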