NCCL backend (#2476)

This commit is contained in:
Anastasiia Filippova
2025-08-21 20:56:15 +02:00
committed by GitHub
parent e843c4d8d5
commit 9392fc3f88
21 changed files with 897 additions and 20 deletions

View File

@@ -22,6 +22,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp

View File

@@ -0,0 +1,51 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/distributed/primitives.h"
#include "mlx/primitives.h"
#include <cassert>
namespace mlx::core {
namespace distributed {
void AllReduce::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& input = inputs[0];
auto& output = outputs[0];
auto& encoder = cu::get_command_encoder(stream());
if (input.is_donatable()) {
output.copy_shared_buffer(input);
} else {
output.set_data(allocator::malloc(output.nbytes()));
}
encoder.set_input_array(input);
encoder.set_output_array(output);
auto capture = encoder.capture_context();
auto& s = stream();
switch (reduce_type_) {
case Sum:
distributed::detail::all_sum(group(), input, output, s);
break;
case Max:
distributed::detail::all_max(group(), input, output, s);
break;
case Min:
distributed::detail::all_min(group(), input, output, s);
break;
default:
throw std::runtime_error(
"Only all reduce sum, max, and min are supported.");
}
}
} // namespace distributed
} // namespace mlx::core

View File

@@ -42,7 +42,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)

View File

@@ -6,3 +6,4 @@ target_sources(
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)

View File

@@ -5,12 +5,17 @@
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/nccl/nccl.h"
#include "mlx/distributed/ring/ring.h"
namespace mlx::core::distributed {
namespace detail {
Stream communication_stream(Group group, StreamOrDevice s /* = {} */) {
return group.raw_group()->communication_stream(s);
}
void all_sum(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_sum(input, output, stream);
}
@@ -37,6 +42,10 @@ void recv(Group group, array& out, int src, Stream stream) {
class EmptyGroup : public GroupImpl {
public:
Stream communication_stream(StreamOrDevice s) override {
return to_stream(s);
}
int rank() override {
return 0;
}
@@ -80,7 +89,7 @@ class EmptyGroup : public GroupImpl {
} // namespace detail
bool is_available() {
return mpi::is_available() || ring::is_available();
return mpi::is_available() || ring::is_available() || nccl::is_available();
}
int Group::rank() const {
@@ -111,6 +120,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = mpi::init(strict);
} else if (bk == "ring") {
group = ring::init(strict);
} else if (bk == "nccl") {
group = nccl::init(strict);
} else if (bk == "any") {
group = ring::init(false);
bk_ = "ring";

View File

@@ -5,6 +5,7 @@
#include <memory>
#include "mlx/array.h"
#include "mlx/utils.h"
namespace mlx::core::distributed {

View File

@@ -13,10 +13,15 @@ class GroupImpl {
public:
virtual ~GroupImpl() {}
// Choose the stream this communication group can operate on
virtual Stream communication_stream(StreamOrDevice s = {}) = 0;
// Group operations
virtual int rank() = 0;
virtual int size() = 0;
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
// Actual communication operations
virtual void all_sum(const array& input, array& output, Stream stream) = 0;
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
virtual void send(const array& input, int dst, Stream stream) = 0;
@@ -25,6 +30,9 @@ class GroupImpl {
virtual void all_min(const array& input, array& output, Stream stream) = 0;
};
/* Define the MLX stream that the communication should happen in. */
Stream communication_stream(Group group, StreamOrDevice s = {});
/* Perform an all reduce sum operation */
void all_sum(Group group, const array& input, array& output, Stream stream);

View File

@@ -349,6 +349,10 @@ class MPIGroup : public GroupImpl {
}
}
Stream communication_stream(StreamOrDevice s) override {
return to_stream(s, Device::cpu);
}
int rank() override {
if (rank_ < 0) {
mpi().rank(comm_, &rank_);

View File

@@ -0,0 +1,8 @@
if(MLX_BUILD_CUDA)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
find_package(NCCL REQUIRED)
target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
endif()

View File

@@ -0,0 +1,359 @@
#include <arpa/inet.h>
#include <cuda_runtime.h>
#include <nccl.h>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <string>
#include <type_traits>
#include "mlx/backend/cuda/device.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/dtype_utils.h"
#include "mlx/utils.h"
namespace mlx::core::distributed::nccl {
#define CHECK_CUDA(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
fprintf( \
stderr, \
"CUDA error %s:%d '%s'\n", \
__FILE__, \
__LINE__, \
cudaGetErrorString(e)); \
exit(1); \
} \
} while (0)
#define CHECK_NCCL(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
fprintf( \
stderr, \
"NCCL error %s:%d '%s'\n", \
__FILE__, \
__LINE__, \
ncclGetErrorString(r)); \
exit(1); \
} \
} while (0)
#define MLX_NCCL_TYPE_LIST(X) \
X(int8_t, ncclChar) \
X(uint8_t, ncclUint8) \
X(int32_t, ncclInt) \
X(uint32_t, ncclUint32) \
X(int64_t, ncclInt64) \
X(uint64_t, ncclUint64) \
X(float16_t, ncclHalf) \
X(bfloat16_t, ncclBfloat16) \
X(float, ncclFloat) \
X(double, ncclDouble)
template <class>
struct nccl_map {
static constexpr bool ok = false; // default: unsupported
};
#define MLX_DEF_NCCL_MAP(T, E) \
template <> \
struct nccl_map<T> { \
static constexpr bool ok = true; \
static constexpr ncclDataType_t value = E; \
};
MLX_NCCL_TYPE_LIST(MLX_DEF_NCCL_MAP)
#undef MLX_DEF_NCCL_MAP
namespace detail {
template <typename F>
void dispatch_dtype(const array& arr, F&& f) {
dispatch_all_types(arr.dtype(), [&](auto type_tag) {
using T = MLX_GET_TYPE(type_tag);
if constexpr (nccl_map<T>::ok) {
f(type_tag, nccl_map<T>::value);
} else {
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
}
});
}
inline void sendAll(int sock, const void* buf, size_t len) {
const char* ptr = reinterpret_cast<const char*>(buf);
while (len > 0) {
ssize_t sent = send(sock, ptr, len, 0);
if (sent <= 0) {
perror("send");
exit(1);
}
ptr += sent;
len -= sent;
}
}
inline void recvAll(int sock, void* buf, size_t len) {
char* ptr = reinterpret_cast<char*>(buf);
while (len > 0) {
ssize_t rec = recv(sock, ptr, len, 0);
if (rec <= 0) {
perror("recv");
exit(1);
}
ptr += rec;
len -= rec;
}
}
inline void bootstrap_unique_id(
ncclUniqueId& id,
int rank,
int size,
const std::string& initMethod) {
// Parse the init method to extract the host and port
if (initMethod.rfind("tcp://", 0) != 0)
throw;
auto hostport = initMethod.substr(6);
auto colon = hostport.find(':');
std::string host = hostport.substr(0, colon);
int port = std::stoi(hostport.substr(colon + 1));
if (rank == 0) {
// create a unique id on the rank 0
CHECK_NCCL(ncclGetUniqueId(&id));
// create a socket to send the unique id to all other ranks
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[nccl] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
sockaddr_in serv = {};
serv.sin_family = AF_INET;
serv.sin_addr.s_addr = htonl(INADDR_ANY);
serv.sin_port = htons(port);
int reuse = 1;
// Without this, if rank-0 crashes or restarts process quickly,
// the OS might refuse to let binding to the same port, so reuse
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
std::ostringstream msg;
msg << "[nccl] setsockopt() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0) {
std::ostringstream msg;
msg << "[nccl] bind() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
if (listen(sock, size - 1) < 0) {
std::ostringstream msg;
msg << "[nccl] listen() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
for (int peer = 1; peer < size; ++peer) {
int conn = accept(sock, nullptr, nullptr);
if (conn < 0) {
std::ostringstream msg;
msg << "[nccl] accept() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
sendAll(conn, &id, sizeof(id));
close(conn);
}
close(sock);
} else {
// Here just wanted to make show that rank 0 has enough time to bind
// so we will retry to connect until max attempts
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[nccl] socket() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
hostent* he = gethostbyname(host.c_str());
if (!he) {
throw std::runtime_error("[nccl] lookup failed for host: " + host);
}
sockaddr_in serv = {};
serv.sin_family = AF_INET;
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
serv.sin_port = htons(port);
const int max_retries = 30;
int attempt = 0;
bool connected = false;
for (attempt = 0; attempt < max_retries; ++attempt) {
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
0) {
connected = true;
std::cout << "[Rank " << rank << "] Connected successfully on attempt "
<< attempt + 1 << std::endl;
break;
}
if (errno != ECONNREFUSED) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}
if (!connected) {
std::ostringstream msg;
msg << "[Rank " << rank << "] connect() failed after " << attempt
<< " retries: " << strerror(errno);
close(sock);
throw std::runtime_error(msg.str());
}
recvAll(sock, &id, sizeof(id));
close(sock);
}
}
} // namespace detail
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
class NCCLGroup : public GroupImpl {
public:
NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
: rank_(worldRank),
size_(worldSize),
comm_(nullptr),
initMethod_(initMethod) {
if (initialized_)
return;
int ndev;
CHECK_CUDA(cudaGetDeviceCount(&ndev));
CHECK_CUDA(cudaSetDevice(rank_ % ndev));
detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
initialized_ = true;
}
~NCCLGroup() {
ncclCommDestroy(comm_);
ncclGroupEnd();
initialized_ = false;
}
Stream communication_stream(StreamOrDevice s) override {
return to_stream(s, Device::gpu);
}
int rank() override {
return rank_;
}
int size() override {
return size_;
}
void all_sum(const array& input, array& output, Stream stream) override {
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclSum);
});
}
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
throw std::runtime_error("[nccl] Group split not supported.");
}
void all_gather(const array& input, array& output, Stream stream) override {
throw std::runtime_error(
"[nccl] All gather not supported in NCCL backend.");
}
void send(const array& input, int dst, Stream stream) override {
throw std::runtime_error("[nccl] Send not supported in NCCL backend.");
}
void recv(array& output, int src, Stream stream) override {
throw std::runtime_error("[nccl] Recv not supported in NCCL backend.");
}
void all_max(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[nccl] All max not supported in NCCL backend.");
}
void all_min(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[nccl] All min not supported in NCCL backend.");
}
template <typename T>
void all_reduce_impl(
const array& input,
array& output,
Stream stream,
ncclDataType_t dt,
ncclRedOp_t op) {
auto& encoder = cu::get_command_encoder(stream);
CHECK_NCCL(ncclAllReduce(
input.data<T>(),
output.data<T>(),
input.size(),
dt,
op,
comm_,
encoder.stream()));
}
int rank_, size_;
std::string initMethod_;
ncclUniqueId uniqueId_;
ncclComm_t comm_;
bool initialized_ = false;
};
bool is_available() {
return true;
}
namespace detail {
static std::string get_env_var_or_throw(const char* env_var_name) {
const char* value = std::getenv(env_var_name);
if (value == nullptr) {
std::ostringstream msg;
msg << "[nccl] Required environment variable '" << env_var_name
<< "' is not set. "
<< "Please set it before initializing the distributed backend.";
throw std::runtime_error(msg.str());
}
return std::string(value);
}
} // namespace detail
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP");
std::string port = detail::get_env_var_or_throw("NCCL_PORT");
std::string rank_str = detail::get_env_var_or_throw("MLX_RANK");
std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE");
int rank = std::stoi(rank_str);
int n_nodes = std::stoi(n_nodes_str);
std::string init_method = "tcp://" + host + ":" + port;
return std::make_shared<NCCLGroup>(rank, n_nodes, init_method);
}
} // namespace mlx::core::distributed::nccl

View File

@@ -0,0 +1,12 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::nccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);
} // namespace mlx::core::distributed::nccl

View File

@@ -0,0 +1,20 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/nccl/nccl.h"
namespace mlx::core::distributed::nccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available() {
return false;
}
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
if (strict) {
throw std::runtime_error("Cannot initialize nccl distributed backend.");
}
return nullptr;
}
} // namespace mlx::core::distributed::nccl

View File

@@ -2,6 +2,9 @@
#include <sstream>
#include "mlx/backend/cuda/cuda.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
@@ -28,11 +31,12 @@ array all_sum(
if (group.size() == 1) {
return x;
}
auto stream = detail::communication_stream(group, s);
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Sum),
std::make_shared<AllReduce>(stream, group, AllReduce::Sum),
{x});
}
@@ -45,11 +49,12 @@ array all_max(
if (group.size() == 1) {
return x;
}
auto stream = detail::communication_stream(group, s);
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Max),
std::make_shared<AllReduce>(stream, group, AllReduce::Max),
{x});
}
@@ -62,11 +67,12 @@ array all_min(
if (group.size() == 1) {
return x;
}
auto stream = detail::communication_stream(group, s);
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Min),
std::make_shared<AllReduce>(stream, group, AllReduce::Min),
{x});
}
@@ -79,6 +85,7 @@ array all_gather(
if (group.size() == 1) {
return x;
}
auto stream = detail::communication_stream(group, s);
auto result_shape = x.shape();
if (result_shape.size() == 0) {
@@ -89,7 +96,7 @@ array all_gather(
return array(
std::move(result_shape),
x.dtype(),
std::make_shared<AllGather>(to_stream(s, Device::cpu), group),
std::make_shared<AllGather>(stream, group),
{x});
}
@@ -103,6 +110,7 @@ array send(
if (group.size() == 1) {
throw std::invalid_argument("Cannot send to a singleton group");
}
auto stream = detail::communication_stream(group, s);
if (dst < 0 || dst >= group.size()) {
std::ostringstream msg;
@@ -112,10 +120,7 @@ array send(
}
return array(
x.shape(),
x.dtype(),
std::make_shared<Send>(to_stream(s, Device::cpu), group, dst),
{x});
x.shape(), x.dtype(), std::make_shared<Send>(stream, group, dst), {x});
}
array recv(
@@ -129,6 +134,7 @@ array recv(
if (group.size() == 1) {
throw std::invalid_argument("Cannot recv from a singleton group");
}
auto stream = detail::communication_stream(group, s);
if (src < 0 || src >= group.size()) {
std::ostringstream msg;
@@ -139,7 +145,7 @@ array recv(
return array(
std::move(shape),
std::move(dtype),
std::make_shared<Recv>(to_stream(s, Device::cpu), group, src),
std::make_shared<Recv>(stream, group, src),
std::vector<array>{});
}

View File

@@ -619,6 +619,10 @@ class RingGroup : public GroupImpl {
}
}
Stream communication_stream(StreamOrDevice s) override {
return to_stream(s, Device::cpu);
}
int rank() override {
return rank_;
}