NCCL backend (#2476)

Anastasiia Filippova 2025-08-21 20:56:15 +02:00 committed by GitHub
parent e843c4d8d5
commit 9392fc3f88
21 changed files with 897 additions and 20 deletions


@ -222,6 +222,7 @@ jobs:
sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libnccl2 libnccl-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64

cmake/FindNCCL.cmake (new file, 54 lines)

@ -0,0 +1,54 @@
# FindNCCL.cmake
# This module finds the NVIDIA NCCL library and its include directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()


@ -271,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
When building either the Python or C++ APIs make sure to pass the cmake flag


@ -22,6 +22,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp


@ -0,0 +1,51 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/distributed/primitives.h"
#include "mlx/primitives.h"
#include <cassert>
namespace mlx::core {
namespace distributed {
void AllReduce::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& input = inputs[0];
auto& output = outputs[0];
auto& encoder = cu::get_command_encoder(stream());
if (input.is_donatable()) {
output.copy_shared_buffer(input);
} else {
output.set_data(allocator::malloc(output.nbytes()));
}
encoder.set_input_array(input);
encoder.set_output_array(output);
auto capture = encoder.capture_context();
auto& s = stream();
switch (reduce_type_) {
case Sum:
distributed::detail::all_sum(group(), input, output, s);
break;
case Max:
distributed::detail::all_max(group(), input, output, s);
break;
case Min:
distributed::detail::all_min(group(), input, output, s);
break;
default:
throw std::runtime_error(
"Only all reduce sum, max, and min are supported.");
}
}
} // namespace distributed
} // namespace mlx::core
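
For context, a small Python sketch (illustrative, not part of the diff) of the donation behaviour the eval_gpu path above relies on: when the input buffer is donatable it is reused for the output, so a bare all_sum should not raise peak memory beyond the input allocation.

import mlx.core as mx

# Assumes an initialized distributed group with a GPU-capable backend.
x = mx.random.normal((1024,))
mx.eval(x)

# x can be donated to y, so the reduction reuses x's buffer in place.
y = mx.distributed.all_sum(x)
mx.eval(y)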


@ -42,7 +42,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)


@ -6,3 +6,4 @@ target_sources(
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)


@ -5,12 +5,17 @@
#include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h" #include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/nccl/nccl.h"
#include "mlx/distributed/ring/ring.h" #include "mlx/distributed/ring/ring.h"
namespace mlx::core::distributed { namespace mlx::core::distributed {
namespace detail { namespace detail {
Stream communication_stream(Group group, StreamOrDevice s /* = {} */) {
return group.raw_group()->communication_stream(s);
}
void all_sum(Group group, const array& input, array& output, Stream stream) { void all_sum(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_sum(input, output, stream); group.raw_group()->all_sum(input, output, stream);
} }
@ -37,6 +42,10 @@ void recv(Group group, array& out, int src, Stream stream) {
class EmptyGroup : public GroupImpl {
public:
Stream communication_stream(StreamOrDevice s) override {
return to_stream(s);
}
int rank() override {
return 0;
}
@ -80,7 +89,7 @@ class EmptyGroup : public GroupImpl {
} // namespace detail
bool is_available() {
return mpi::is_available() || ring::is_available() || nccl::is_available();
}
int Group::rank() const {
@ -111,6 +120,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = mpi::init(strict);
} else if (bk == "ring") {
group = ring::init(strict);
} else if (bk == "nccl") {
group = nccl::init(strict);
} else if (bk == "any") {
group = ring::init(false);
bk_ = "ring";


@ -5,6 +5,7 @@
#include <memory>
#include "mlx/array.h"
#include "mlx/utils.h"
namespace mlx::core::distributed {


@ -13,10 +13,15 @@ class GroupImpl {
public:
virtual ~GroupImpl() {}
// Choose the stream this communication group can operate on
virtual Stream communication_stream(StreamOrDevice s = {}) = 0;
// Group operations
virtual int rank() = 0;
virtual int size() = 0;
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
// Actual communication operations
virtual void all_sum(const array& input, array& output, Stream stream) = 0;
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
virtual void send(const array& input, int dst, Stream stream) = 0;
@ -25,6 +30,9 @@ class GroupImpl {
virtual void all_min(const array& input, array& output, Stream stream) = 0;
};
/* Define the MLX stream that the communication should happen in. */
Stream communication_stream(Group group, StreamOrDevice s = {});
/* Perform an all reduce sum operation */
void all_sum(Group group, const array& input, array& output, Stream stream);


@ -349,6 +349,10 @@ class MPIGroup : public GroupImpl {
}
}
Stream communication_stream(StreamOrDevice s) override {
return to_stream(s, Device::cpu);
}
int rank() override {
if (rank_ < 0) {
mpi().rank(comm_, &rank_);


@ -0,0 +1,8 @@
if(MLX_BUILD_CUDA)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
find_package(NCCL REQUIRED)
target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
endif()


@ -0,0 +1,359 @@
#include <arpa/inet.h>
#include <cuda_runtime.h>
#include <nccl.h>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <type_traits>
#include "mlx/backend/cuda/device.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/dtype_utils.h"
#include "mlx/utils.h"
namespace mlx::core::distributed::nccl {
#define CHECK_CUDA(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
fprintf( \
stderr, \
"CUDA error %s:%d '%s'\n", \
__FILE__, \
__LINE__, \
cudaGetErrorString(e)); \
exit(1); \
} \
} while (0)
#define CHECK_NCCL(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
fprintf( \
stderr, \
"NCCL error %s:%d '%s'\n", \
__FILE__, \
__LINE__, \
ncclGetErrorString(r)); \
exit(1); \
} \
} while (0)
#define MLX_NCCL_TYPE_LIST(X) \
X(int8_t, ncclChar) \
X(uint8_t, ncclUint8) \
X(int32_t, ncclInt) \
X(uint32_t, ncclUint32) \
X(int64_t, ncclInt64) \
X(uint64_t, ncclUint64) \
X(float16_t, ncclHalf) \
X(bfloat16_t, ncclBfloat16) \
X(float, ncclFloat) \
X(double, ncclDouble)
template <class>
struct nccl_map {
static constexpr bool ok = false; // default: unsupported
};
#define MLX_DEF_NCCL_MAP(T, E) \
template <> \
struct nccl_map<T> { \
static constexpr bool ok = true; \
static constexpr ncclDataType_t value = E; \
};
MLX_NCCL_TYPE_LIST(MLX_DEF_NCCL_MAP)
#undef MLX_DEF_NCCL_MAP
namespace detail {
template <typename F>
void dispatch_dtype(const array& arr, F&& f) {
dispatch_all_types(arr.dtype(), [&](auto type_tag) {
using T = MLX_GET_TYPE(type_tag);
if constexpr (nccl_map<T>::ok) {
f(type_tag, nccl_map<T>::value);
} else {
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
}
});
}
inline void sendAll(int sock, const void* buf, size_t len) {
const char* ptr = reinterpret_cast<const char*>(buf);
while (len > 0) {
ssize_t sent = send(sock, ptr, len, 0);
if (sent <= 0) {
perror("send");
exit(1);
}
ptr += sent;
len -= sent;
}
}
inline void recvAll(int sock, void* buf, size_t len) {
char* ptr = reinterpret_cast<char*>(buf);
while (len > 0) {
ssize_t rec = recv(sock, ptr, len, 0);
if (rec <= 0) {
perror("recv");
exit(1);
}
ptr += rec;
len -= rec;
}
}
inline void bootstrap_unique_id(
ncclUniqueId& id,
int rank,
int size,
const std::string& initMethod) {
// Parse the init method to extract the host and port
if (initMethod.rfind("tcp://", 0) != 0) {
throw std::invalid_argument(
"[nccl] Init method must be of the form tcp://<host>:<port>");
}
auto hostport = initMethod.substr(6);
auto colon = hostport.find(':');
std::string host = hostport.substr(0, colon);
int port = std::stoi(hostport.substr(colon + 1));
if (rank == 0) {
// create the unique id on rank 0
CHECK_NCCL(ncclGetUniqueId(&id));
// create a socket to send the unique id to all other ranks
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[nccl] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
sockaddr_in serv = {};
serv.sin_family = AF_INET;
serv.sin_addr.s_addr = htonl(INADDR_ANY);
serv.sin_port = htons(port);
int reuse = 1;
// Without this, if rank 0 crashes or restarts quickly, the OS may
// refuse to bind to the same port again, so allow address reuse.
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
std::ostringstream msg;
msg << "[nccl] setsockopt() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0) {
std::ostringstream msg;
msg << "[nccl] bind() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
if (listen(sock, size - 1) < 0) {
std::ostringstream msg;
msg << "[nccl] listen() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
for (int peer = 1; peer < size; ++peer) {
int conn = accept(sock, nullptr, nullptr);
if (conn < 0) {
std::ostringstream msg;
msg << "[nccl] accept() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
sendAll(conn, &id, sizeof(id));
close(conn);
}
close(sock);
} else {
// Give rank 0 enough time to bind and listen: retry the connection
// until it succeeds or the maximum number of attempts is reached.
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[nccl] socket() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
hostent* he = gethostbyname(host.c_str());
if (!he) {
throw std::runtime_error("[nccl] lookup failed for host: " + host);
}
sockaddr_in serv = {};
serv.sin_family = AF_INET;
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
serv.sin_port = htons(port);
const int max_retries = 30;
int attempt = 0;
bool connected = false;
for (attempt = 0; attempt < max_retries; ++attempt) {
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
0) {
connected = true;
std::cout << "[Rank " << rank << "] Connected successfully on attempt "
<< attempt + 1 << std::endl;
break;
}
if (errno != ECONNREFUSED) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}
if (!connected) {
std::ostringstream msg;
msg << "[Rank " << rank << "] connect() failed after " << attempt
<< " retries: " << strerror(errno);
close(sock);
throw std::runtime_error(msg.str());
}
recvAll(sock, &id, sizeof(id));
close(sock);
}
}
} // namespace detail
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
class NCCLGroup : public GroupImpl {
public:
NCCLGroup(int worldRank, int worldSize, const std::string& initMethod)
: rank_(worldRank),
size_(worldSize),
comm_(nullptr),
initMethod_(initMethod) {
if (initialized_)
return;
int ndev;
CHECK_CUDA(cudaGetDeviceCount(&ndev));
CHECK_CUDA(cudaSetDevice(rank_ % ndev));
detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
initialized_ = true;
}
~NCCLGroup() {
ncclCommDestroy(comm_);
ncclGroupEnd();
initialized_ = false;
}
Stream communication_stream(StreamOrDevice s) override {
return to_stream(s, Device::gpu);
}
int rank() override {
return rank_;
}
int size() override {
return size_;
}
void all_sum(const array& input, array& output, Stream stream) override {
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclSum);
});
}
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
throw std::runtime_error("[nccl] Group split not supported.");
}
void all_gather(const array& input, array& output, Stream stream) override {
throw std::runtime_error(
"[nccl] All gather not supported in NCCL backend.");
}
void send(const array& input, int dst, Stream stream) override {
throw std::runtime_error("[nccl] Send not supported in NCCL backend.");
}
void recv(array& output, int src, Stream stream) override {
throw std::runtime_error("[nccl] Recv not supported in NCCL backend.");
}
void all_max(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[nccl] All max not supported in NCCL backend.");
}
void all_min(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[nccl] All min not supported in NCCL backend.");
}
template <typename T>
void all_reduce_impl(
const array& input,
array& output,
Stream stream,
ncclDataType_t dt,
ncclRedOp_t op) {
auto& encoder = cu::get_command_encoder(stream);
CHECK_NCCL(ncclAllReduce(
input.data<T>(),
output.data<T>(),
input.size(),
dt,
op,
comm_,
encoder.stream()));
}
int rank_, size_;
std::string initMethod_;
ncclUniqueId uniqueId_;
ncclComm_t comm_;
bool initialized_ = false;
};
bool is_available() {
return true;
}
namespace detail {
static std::string get_env_var_or_throw(const char* env_var_name) {
const char* value = std::getenv(env_var_name);
if (value == nullptr) {
std::ostringstream msg;
msg << "[nccl] Required environment variable '" << env_var_name
<< "' is not set. "
<< "Please set it before initializing the distributed backend.";
throw std::runtime_error(msg.str());
}
return std::string(value);
}
} // namespace detail
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP");
std::string port = detail::get_env_var_or_throw("NCCL_PORT");
std::string rank_str = detail::get_env_var_or_throw("MLX_RANK");
std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE");
int rank = std::stoi(rank_str);
int n_nodes = std::stoi(n_nodes_str);
std::string init_method = "tcp://" + host + ":" + port;
return std::make_shared<NCCLGroup>(rank, n_nodes, init_method);
}
} // namespace mlx::core::distributed::nccl
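
As a usage sketch (not part of the diff): the environment variables read by init() above are all that is needed to bring up a single rank by hand. The variable names come from the code; everything else below is illustrative.

import os
import mlx.core as mx

# Address of rank 0 and the TCP port used to exchange the NCCL unique id.
os.environ["NCCL_HOST_IP"] = "127.0.0.1"
os.environ["NCCL_PORT"] = "12345"
os.environ["MLX_RANK"] = "0"           # this process's rank
os.environ["MLX_WORLD_SIZE"] = "1"     # total number of ranks

world = mx.distributed.init(strict=True, backend="nccl")
mx.set_default_device(mx.Device(mx.gpu, world.rank() % 8))

x = mx.ones((1024,))
y = mx.distributed.all_sum(x)  # all-reduced over the NCCL communicator
mx.eval(y)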


@ -0,0 +1,12 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::nccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);
} // namespace mlx::core::distributed::nccl


@ -0,0 +1,20 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/nccl/nccl.h"
namespace mlx::core::distributed::nccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available() {
return false;
}
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
if (strict) {
throw std::runtime_error("Cannot initialize nccl distributed backend.");
}
return nullptr;
}
} // namespace mlx::core::distributed::nccl


@ -2,6 +2,9 @@
#include <sstream>
#include "mlx/backend/cuda/cuda.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
@ -28,11 +31,12 @@ array all_sum(
if (group.size() == 1) {
return x;
}
auto stream = detail::communication_stream(group, s);
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(stream, group, AllReduce::Sum),
{x});
}
@ -45,11 +49,12 @@ array all_max(
if (group.size() == 1) {
return x;
}
auto stream = detail::communication_stream(group, s);
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(stream, group, AllReduce::Max),
{x});
}
@ -62,11 +67,12 @@ array all_min(
if (group.size() == 1) {
return x;
}
auto stream = detail::communication_stream(group, s);
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(stream, group, AllReduce::Min),
{x});
}
@ -79,6 +85,7 @@ array all_gather(
if (group.size() == 1) {
return x;
}
auto stream = detail::communication_stream(group, s);
auto result_shape = x.shape();
if (result_shape.size() == 0) {
@ -89,7 +96,7 @@ array all_gather(
return array(
std::move(result_shape),
x.dtype(),
std::make_shared<AllGather>(stream, group),
{x});
}
@ -103,6 +110,7 @@ array send(
if (group.size() == 1) {
throw std::invalid_argument("Cannot send to a singleton group");
}
auto stream = detail::communication_stream(group, s);
if (dst < 0 || dst >= group.size()) {
std::ostringstream msg;
@ -112,10 +120,7 @@
}
return array(
x.shape(), x.dtype(), std::make_shared<Send>(stream, group, dst), {x});
}
array recv(
@ -129,6 +134,7 @@ array recv(
if (group.size() == 1) {
throw std::invalid_argument("Cannot recv from a singleton group");
}
auto stream = detail::communication_stream(group, s);
if (src < 0 || src >= group.size()) {
std::ostringstream msg;
@ -139,7 +145,7 @@
return array(
std::move(shape),
std::move(dtype),
std::make_shared<Recv>(stream, group, src),
std::vector<array>{});
}


@ -619,6 +619,10 @@ class RingGroup : public GroupImpl {
}
}
Stream communication_stream(StreamOrDevice s) override {
return to_stream(s, Device::cpu);
}
int rank() override {
return rank_;
}


@ -415,6 +415,48 @@ def launch_mpi(parser, hosts, args, command):
pass
def launch_nccl(parser, hosts, args, command):
master_host = hosts[0].ips[0]
if master_host != "127.0.0.1":
raise ValueError("The NCCL backend only supports localhost for now.")
master_port = args.nccl_port
world_size = len(hosts)
base_env = os.environ.copy()
base_env.update(
{
"NCCL_DEBUG": "INFO",
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
"NCCL_HOST_IP": master_host,
"NCCL_PORT": str(master_port),
"MLX_WORLD_SIZE": str(world_size),
}
)
procs = []
try:
for rank in range(world_size):
env = base_env.copy()
env["MLX_RANK"] = str(rank)
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
p = Popen(command, env=env)
procs.append(p)
for p in procs:
ret = p.wait()
if ret != 0:
raise RuntimeError(f"Rank process exited with {ret}")
except (RuntimeError, KeyboardInterrupt) as err:
for p in procs:
if p.poll() is None:
try:
p.kill()
except Exception:
pass
raise
def check_ssh_connections(hosts):
results = [False] * len(hosts)
@ -665,7 +707,7 @@ def distributed_config():
)
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl"],
default="ring",
help="Which distributed backend to configure",
)
@ -737,7 +779,7 @@ def main():
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl"],
default="ring",
help="Which distributed backend to launch",
)
@ -769,6 +811,13 @@ def main():
parser.add_argument(
"--cwd", help="Set the working directory on each node to the provided one"
)
parser.add_argument(
"--nccl-port",
type=int,
default=12345,
help="The port to use for the NCCL communication (only for nccl backend)",
)
args, rest = parser.parse_known_args()
if rest[0] == "--":
rest.pop(0)
@ -799,8 +848,10 @@
# Launch
if args.backend == "ring":
launch_ring(parser, hosts, args, rest)
if args.backend == "mpi":
launch_mpi(parser, hosts, args, rest)
if args.backend == "nccl":
launch_nccl(parser, hosts, args, rest)
if __name__ == "__main__":


@ -76,6 +76,7 @@ def average_gradients(
group: Optional[mx.distributed.Group] = None,
all_reduce_size: int = 32 * 1024**2,
communication_type: Optional[mx.Dtype] = None,
stream: mx.Stream = mx.cpu,
):
"""Average the gradients across the distributed processes in the passed group.
@ -94,6 +95,7 @@ def average_gradients(
communication_type (Optional[mlx.core.Dtype]): If provided cast to this
type before performing the communication. Typically cast to a
smaller float to reduce the communication size. Default: ``None``.
stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
"""
group = group or mx.distributed.init()
N = group.size()
@ -104,7 +106,7 @@ def average_gradients(
def _average(x):
dt = x.dtype
x = x.astype(communication_type) if communication_type is not None else x
return mx.distributed.all_sum(x, stream=stream).astype(dt) / N
if all_reduce_size <= 0:
return tree_map(_average, gradients)
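
A brief sketch of the new stream argument in use (the model and loss below are illustrative; a distributed group is assumed to be initialized, e.g. with the nccl backend):

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.utils import average_gradients

model = nn.Linear(128, 128)

def loss_fn(model, x, y):
    return ((model(x) - y) ** 2).mean()

x = mx.random.normal((8, 128))
y = mx.random.normal((8, 128))
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)

# With stream=mx.gpu the underlying all_sum runs on the GPU stream
# (e.g. via NCCL) instead of the default mx.cpu.
grads = average_gradients(grads, stream=mx.gpu)
mx.eval(grads)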


@ -79,7 +79,7 @@ void init_distributed(nb::module_& parent_module) {
in case ``mx.distributed.is_available()`` returns False otherwise
it throws a runtime error. Default: ``False``
backend (str, optional): Which distributed backend to initialize.
Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
available backends are tried and the first one that succeeds
becomes the global group which will be returned in subsequent
calls. Default: ``any``


@ -0,0 +1,284 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
import mlx_tests
from mlx.nn.layers.distributed import shard_inplace, shard_linear
from mlx.nn.utils import average_gradients
class TestNCCLDistributed(mlx_tests.MLXTestCase):
@classmethod
def setUpClass(cls):
world = mx.distributed.init(strict=True, backend="nccl")
rank = world.rank()
mx.set_default_device(mx.Device(mx.gpu, rank % 8))
def test_all_reduce(self):
world = mx.distributed.init()
dtypes = [
(mx.int8, 0),
(mx.uint8, 0),
(mx.int32, 0),
(mx.uint32, 0),
(mx.float32, 1e-6),
(mx.float16, 5e-3),
(mx.bfloat16, 1e-1),
]
sizes = [
(7,),
(10,),
(1024,),
(1024, 1024),
]
key = mx.random.key(0)
for dt, rtol in dtypes:
for sh in sizes:
x = (
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
).astype(dt)
# All sum
y = mx.distributed.all_sum(x[world.rank()])
z = x.sum(0)
maxrelerror = (y - z).abs()
if rtol > 0:
maxrelerror /= z.abs()
maxrelerror = maxrelerror.max()
self.assertLessEqual(maxrelerror, rtol)
def test_average_gradients(self):
original_all_sum = mx.distributed.all_sum
n_calls = 0
xtype = None
def new_all_sum(x, **kwargs):
nonlocal n_calls
nonlocal xtype
n_calls += 1
if xtype is not None:
self.assertEqual(xtype, x.dtype)
return original_all_sum(x, **kwargs)
mx.distributed.all_sum = new_all_sum
try:
grads = [mx.ones(10) for i in range(10)]
new_grads = average_gradients(grads, stream=mx.gpu)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 1)
n_calls = 0
new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 2)
n_calls = 0
new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 10)
n_calls = 0
xtype = mx.float16
new_grads = average_gradients(
grads,
all_reduce_size=2 * 50,
communication_type=mx.float16,
stream=mx.gpu,
)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 2)
finally:
mx.distributed.all_sum = original_all_sum
def test_donation(self):
x = mx.random.normal((1024,))
mx.eval(x)
mx.synchronize()
mx.reset_peak_memory()
scale = mx.array(2.0)
y = mx.distributed.all_sum(x)
mx.eval(y)
mx.synchronize()
all_sum_only = mx.get_peak_memory()
y = mx.distributed.all_sum(x) * scale
mx.eval(y)
mx.synchronize()
all_sum_with_binary = mx.get_peak_memory()
self.assertEqual(all_sum_only, all_sum_with_binary)
def test_shard_linear(self):
# Seed the prng to have the same inputs and weights generated everywhere
mx.random.seed(0xF0F0F0F0)
# Prepare inputs
world = mx.distributed.init()
part = (
slice(None),
slice(
world.rank() * 1024 // world.size(),
(world.rank() + 1) * 1024 // world.size(),
),
)
x = mx.random.normal((4, 1024))
# Create and shard some linear layers
lin = nn.Linear(1024, 1024, bias=True)
slin1 = shard_linear(lin, "all-to-sharded")
slin2 = shard_linear(lin, "sharded-to-all")
y = lin(x)
y1 = slin1(x)
y2 = slin2(x[part])
self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))
# Check the backward works as expected
def dummy_loss(model, x, y):
return (model(x) * y).sum()
mod = nn.Sequential(
nn.Linear(128, 128),
nn.Linear(128, 128),
nn.Linear(128, 128),
nn.Linear(128, 128),
)
smod = nn.Sequential(
shard_linear(mod.layers[0], "all-to-sharded"),
shard_linear(mod.layers[1], "sharded-to-all"),
shard_linear(mod.layers[2], "all-to-sharded"),
shard_linear(mod.layers[3], "sharded-to-all"),
)
grad1 = nn.value_and_grad(mod, dummy_loss)
grad2 = nn.value_and_grad(smod, dummy_loss)
x = mx.random.normal((4, 128))
y = mx.random.normal((4, 128))
l1, g1 = grad1(mod, x, y)
l2, g2 = grad2(smod, x, y)
mx.eval(l1, g1, l2, g2)
part = slice(
world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
)
self.assertTrue(mx.allclose(l1, l2))
self.assertTrue(
mx.allclose(
g1["layers"][0]["weight"][part],
g2["layers"][0]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][2]["weight"][part],
g2["layers"][2]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][1]["weight"][:, part],
g2["layers"][1]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][3]["weight"][:, part],
g2["layers"][3]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][0]["bias"][part],
g2["layers"][0]["bias"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][2]["bias"][part],
g2["layers"][2]["bias"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
)
)
self.assertTrue(
mx.allclose(
g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
)
)
def test_shard_predicate(self):
mx.random.seed(0xF0F0F0F0)
class MyConv(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.aggregate = kwargs.pop("aggregate", False)
self.conv = nn.Conv2d(*args, **kwargs)
def __call__(self, x):
x = self.conv(x)
if self.aggregate:
x = mx.distributed.all_sum(x)
return x
def sharding(path, weight):
parts = path.split(".")
even = int(parts[1]) % 2 == 0
if even:
return 0
else:
return -1 if parts[-1] != "bias" else None
mod = nn.Sequential(
MyConv(3, 128, kernel_size=3),
MyConv(128, 128, kernel_size=3),
MyConv(128, 128, kernel_size=3),
MyConv(128, 3, kernel_size=3),
)
smod = nn.Sequential(
MyConv(3, 128, kernel_size=3),
MyConv(128, 128, kernel_size=3, aggregate=True),
MyConv(128, 128, kernel_size=3),
MyConv(128, 3, kernel_size=3, aggregate=True),
)
smod.update(mod.parameters())
shard_inplace(smod, sharding)
x = mx.random.normal((4, 16, 16, 3))
y1 = mod(x)
y2 = smod(x)
self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()