mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-23 05:47:46 +08:00
NCCL backend (#2476)
This commit is contained in:
parent
e843c4d8d5
commit
9392fc3f88
@ -222,6 +222,7 @@ jobs:
|
|||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install libcudnn9-dev-cuda-12
|
sudo apt-get install libcudnn9-dev-cuda-12
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
|
sudo apt-get install libnccl2 libnccl-dev
|
||||||
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
|
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
|
||||||
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
|
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
|
||||||
rm -rf ccache-4.11.3-linux-x86_64
|
rm -rf ccache-4.11.3-linux-x86_64
|
||||||
|
54
cmake/FindNCCL.cmake
Normal file
54
cmake/FindNCCL.cmake
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
|
||||||
|
# directories.
|
||||||
|
|
||||||
|
set(NCCL_ROOT_DIR
|
||||||
|
$ENV{NCCL_ROOT_DIR}
|
||||||
|
CACHE PATH "Folder contains NVIDIA NCCL")
|
||||||
|
|
||||||
|
find_path(
|
||||||
|
NCCL_INCLUDE_DIRS
|
||||||
|
NAMES nccl.h
|
||||||
|
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/include)
|
||||||
|
|
||||||
|
if($ENV{USE_STATIC_NCCL})
|
||||||
|
message(
|
||||||
|
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
|
||||||
|
set(NCCL_LIBNAME "libnccl_static.a")
|
||||||
|
else()
|
||||||
|
set(NCCL_LIBNAME "nccl")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
find_library(
|
||||||
|
NCCL_LIBRARIES
|
||||||
|
NAMES ${NCCL_LIBNAME}
|
||||||
|
HINTS ${NCCL_LIB_DIR}
|
||||||
|
${NCCL_ROOT_DIR}
|
||||||
|
${NCCL_ROOT_DIR}/lib
|
||||||
|
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
|
||||||
|
${NCCL_ROOT_DIR}/lib64
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/lib
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
||||||
|
|
||||||
|
include(FindPackageHandleStandardArgs)
|
||||||
|
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
|
||||||
|
NCCL_LIBRARIES)
|
||||||
|
|
||||||
|
if(NCCL_FOUND)
|
||||||
|
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
|
||||||
|
message(
|
||||||
|
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
|
||||||
|
file(
|
||||||
|
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
|
||||||
|
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
|
||||||
|
LIMIT_COUNT 1)
|
||||||
|
if(NCCL_MAJOR_VERSION_DEFINED)
|
||||||
|
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
|
||||||
|
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
|
||||||
|
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
|
||||||
|
endif()
|
||||||
|
message(
|
||||||
|
STATUS
|
||||||
|
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
|
||||||
|
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
|
||||||
|
endif()
|
@ -271,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
|
|||||||
dpkg -i cuda-keyring_1.1-1_all.deb
|
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
apt-get update -y
|
apt-get update -y
|
||||||
apt-get -y install cuda-toolkit-12-9
|
apt-get -y install cuda-toolkit-12-9
|
||||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
|
||||||
|
|
||||||
|
|
||||||
When building either the Python or C++ APIs make sure to pass the cmake flag
|
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||||
|
@ -22,6 +22,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||||
|
51
mlx/backend/cuda/distributed.cu
Normal file
51
mlx/backend/cuda/distributed.cu
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/distributed/primitives.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
namespace distributed {
|
||||||
|
void AllReduce::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(outputs.size() == 1);
|
||||||
|
|
||||||
|
auto& input = inputs[0];
|
||||||
|
auto& output = outputs[0];
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(stream());
|
||||||
|
|
||||||
|
if (input.is_donatable()) {
|
||||||
|
output.copy_shared_buffer(input);
|
||||||
|
} else {
|
||||||
|
output.set_data(allocator::malloc(output.nbytes()));
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder.set_input_array(input);
|
||||||
|
encoder.set_output_array(output);
|
||||||
|
|
||||||
|
auto capture = encoder.capture_context();
|
||||||
|
auto& s = stream();
|
||||||
|
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case Sum:
|
||||||
|
distributed::detail::all_sum(group(), input, output, s);
|
||||||
|
break;
|
||||||
|
case Max:
|
||||||
|
distributed::detail::all_max(group(), input, output, s);
|
||||||
|
break;
|
||||||
|
case Min:
|
||||||
|
distributed::detail::all_min(group(), input, output, s);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Only all reduce sum, max, and min are supported.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace distributed
|
||||||
|
} // namespace mlx::core
|
@ -42,7 +42,6 @@ NO_GPU_MULTI(Eig)
|
|||||||
NO_GPU_MULTI(Eigh)
|
NO_GPU_MULTI(Eigh)
|
||||||
|
|
||||||
namespace distributed {
|
namespace distributed {
|
||||||
NO_GPU_MULTI(AllReduce)
|
|
||||||
NO_GPU_MULTI(AllGather)
|
NO_GPU_MULTI(AllGather)
|
||||||
NO_GPU_MULTI(Send)
|
NO_GPU_MULTI(Send)
|
||||||
NO_GPU_MULTI(Recv)
|
NO_GPU_MULTI(Recv)
|
||||||
|
@ -6,3 +6,4 @@ target_sources(
|
|||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
||||||
|
@ -5,12 +5,17 @@
|
|||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
#include "mlx/distributed/mpi/mpi.h"
|
#include "mlx/distributed/mpi/mpi.h"
|
||||||
|
#include "mlx/distributed/nccl/nccl.h"
|
||||||
#include "mlx/distributed/ring/ring.h"
|
#include "mlx/distributed/ring/ring.h"
|
||||||
|
|
||||||
namespace mlx::core::distributed {
|
namespace mlx::core::distributed {
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
|
|
||||||
|
Stream communication_stream(Group group, StreamOrDevice s /* = {} */) {
|
||||||
|
return group.raw_group()->communication_stream(s);
|
||||||
|
}
|
||||||
|
|
||||||
void all_sum(Group group, const array& input, array& output, Stream stream) {
|
void all_sum(Group group, const array& input, array& output, Stream stream) {
|
||||||
group.raw_group()->all_sum(input, output, stream);
|
group.raw_group()->all_sum(input, output, stream);
|
||||||
}
|
}
|
||||||
@ -37,6 +42,10 @@ void recv(Group group, array& out, int src, Stream stream) {
|
|||||||
|
|
||||||
class EmptyGroup : public GroupImpl {
|
class EmptyGroup : public GroupImpl {
|
||||||
public:
|
public:
|
||||||
|
Stream communication_stream(StreamOrDevice s) override {
|
||||||
|
return to_stream(s);
|
||||||
|
}
|
||||||
|
|
||||||
int rank() override {
|
int rank() override {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
@ -80,7 +89,7 @@ class EmptyGroup : public GroupImpl {
|
|||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
bool is_available() {
|
bool is_available() {
|
||||||
return mpi::is_available() || ring::is_available();
|
return mpi::is_available() || ring::is_available() || nccl::is_available();
|
||||||
}
|
}
|
||||||
|
|
||||||
int Group::rank() const {
|
int Group::rank() const {
|
||||||
@ -111,6 +120,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
|||||||
group = mpi::init(strict);
|
group = mpi::init(strict);
|
||||||
} else if (bk == "ring") {
|
} else if (bk == "ring") {
|
||||||
group = ring::init(strict);
|
group = ring::init(strict);
|
||||||
|
} else if (bk == "nccl") {
|
||||||
|
group = nccl::init(strict);
|
||||||
} else if (bk == "any") {
|
} else if (bk == "any") {
|
||||||
group = ring::init(false);
|
group = ring::init(false);
|
||||||
bk_ = "ring";
|
bk_ = "ring";
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core::distributed {
|
namespace mlx::core::distributed {
|
||||||
|
|
||||||
|
@ -13,10 +13,15 @@ class GroupImpl {
|
|||||||
public:
|
public:
|
||||||
virtual ~GroupImpl() {}
|
virtual ~GroupImpl() {}
|
||||||
|
|
||||||
|
// Choose the stream this communication group can operate on
|
||||||
|
virtual Stream communication_stream(StreamOrDevice s = {}) = 0;
|
||||||
|
|
||||||
|
// Group operations
|
||||||
virtual int rank() = 0;
|
virtual int rank() = 0;
|
||||||
virtual int size() = 0;
|
virtual int size() = 0;
|
||||||
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
|
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
|
||||||
|
|
||||||
|
// Actual communication operations
|
||||||
virtual void all_sum(const array& input, array& output, Stream stream) = 0;
|
virtual void all_sum(const array& input, array& output, Stream stream) = 0;
|
||||||
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
|
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
|
||||||
virtual void send(const array& input, int dst, Stream stream) = 0;
|
virtual void send(const array& input, int dst, Stream stream) = 0;
|
||||||
@ -25,6 +30,9 @@ class GroupImpl {
|
|||||||
virtual void all_min(const array& input, array& output, Stream stream) = 0;
|
virtual void all_min(const array& input, array& output, Stream stream) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/* Define the MLX stream that the communication should happen in. */
|
||||||
|
Stream communication_stream(Group group, StreamOrDevice s = {});
|
||||||
|
|
||||||
/* Perform an all reduce sum operation */
|
/* Perform an all reduce sum operation */
|
||||||
void all_sum(Group group, const array& input, array& output, Stream stream);
|
void all_sum(Group group, const array& input, array& output, Stream stream);
|
||||||
|
|
||||||
|
@ -349,6 +349,10 @@ class MPIGroup : public GroupImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Stream communication_stream(StreamOrDevice s) override {
|
||||||
|
return to_stream(s, Device::cpu);
|
||||||
|
}
|
||||||
|
|
||||||
int rank() override {
|
int rank() override {
|
||||||
if (rank_ < 0) {
|
if (rank_ < 0) {
|
||||||
mpi().rank(comm_, &rank_);
|
mpi().rank(comm_, &rank_);
|
||||||
|
8
mlx/distributed/nccl/CMakeLists.txt
Normal file
8
mlx/distributed/nccl/CMakeLists.txt
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
if(MLX_BUILD_CUDA)
|
||||||
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
|
||||||
|
find_package(NCCL REQUIRED)
|
||||||
|
target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
|
||||||
|
target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
|
||||||
|
else()
|
||||||
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
|
||||||
|
endif()
|
359
mlx/distributed/nccl/nccl.cpp
Normal file
359
mlx/distributed/nccl/nccl.cpp
Normal file
@ -0,0 +1,359 @@
|
|||||||
|
#include <arpa/inet.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <nccl.h>
|
||||||
|
#include <netdb.h>
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <iostream>
|
||||||
|
#include <mutex>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/distributed/distributed.h"
|
||||||
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
|
#define CHECK_CUDA(cmd) \
|
||||||
|
do { \
|
||||||
|
cudaError_t e = cmd; \
|
||||||
|
if (e != cudaSuccess) { \
|
||||||
|
fprintf( \
|
||||||
|
stderr, \
|
||||||
|
"CUDA error %s:%d '%s'\n", \
|
||||||
|
__FILE__, \
|
||||||
|
__LINE__, \
|
||||||
|
cudaGetErrorString(e)); \
|
||||||
|
exit(1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define CHECK_NCCL(cmd) \
|
||||||
|
do { \
|
||||||
|
ncclResult_t r = cmd; \
|
||||||
|
if (r != ncclSuccess) { \
|
||||||
|
fprintf( \
|
||||||
|
stderr, \
|
||||||
|
"NCCL error %s:%d '%s'\n", \
|
||||||
|
__FILE__, \
|
||||||
|
__LINE__, \
|
||||||
|
ncclGetErrorString(r)); \
|
||||||
|
exit(1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define MLX_NCCL_TYPE_LIST(X) \
|
||||||
|
X(int8_t, ncclChar) \
|
||||||
|
X(uint8_t, ncclUint8) \
|
||||||
|
X(int32_t, ncclInt) \
|
||||||
|
X(uint32_t, ncclUint32) \
|
||||||
|
X(int64_t, ncclInt64) \
|
||||||
|
X(uint64_t, ncclUint64) \
|
||||||
|
X(float16_t, ncclHalf) \
|
||||||
|
X(bfloat16_t, ncclBfloat16) \
|
||||||
|
X(float, ncclFloat) \
|
||||||
|
X(double, ncclDouble)
|
||||||
|
|
||||||
|
template <class>
|
||||||
|
struct nccl_map {
|
||||||
|
static constexpr bool ok = false; // default: unsupported
|
||||||
|
};
|
||||||
|
|
||||||
|
#define MLX_DEF_NCCL_MAP(T, E) \
|
||||||
|
template <> \
|
||||||
|
struct nccl_map<T> { \
|
||||||
|
static constexpr bool ok = true; \
|
||||||
|
static constexpr ncclDataType_t value = E; \
|
||||||
|
};
|
||||||
|
|
||||||
|
MLX_NCCL_TYPE_LIST(MLX_DEF_NCCL_MAP)
|
||||||
|
#undef MLX_DEF_NCCL_MAP
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void dispatch_dtype(const array& arr, F&& f) {
|
||||||
|
dispatch_all_types(arr.dtype(), [&](auto type_tag) {
|
||||||
|
using T = MLX_GET_TYPE(type_tag);
|
||||||
|
if constexpr (nccl_map<T>::ok) {
|
||||||
|
f(type_tag, nccl_map<T>::value);
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void sendAll(int sock, const void* buf, size_t len) {
|
||||||
|
const char* ptr = reinterpret_cast<const char*>(buf);
|
||||||
|
while (len > 0) {
|
||||||
|
ssize_t sent = send(sock, ptr, len, 0);
|
||||||
|
if (sent <= 0) {
|
||||||
|
perror("send");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
ptr += sent;
|
||||||
|
len -= sent;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void recvAll(int sock, void* buf, size_t len) {
|
||||||
|
char* ptr = reinterpret_cast<char*>(buf);
|
||||||
|
while (len > 0) {
|
||||||
|
ssize_t rec = recv(sock, ptr, len, 0);
|
||||||
|
if (rec <= 0) {
|
||||||
|
perror("recv");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
ptr += rec;
|
||||||
|
len -= rec;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void bootstrap_unique_id(
|
||||||
|
ncclUniqueId& id,
|
||||||
|
int rank,
|
||||||
|
int size,
|
||||||
|
const std::string& initMethod) {
|
||||||
|
// Parse the init method to extract the host and port
|
||||||
|
if (initMethod.rfind("tcp://", 0) != 0)
|
||||||
|
throw;
|
||||||
|
auto hostport = initMethod.substr(6);
|
||||||
|
auto colon = hostport.find(':');
|
||||||
|
std::string host = hostport.substr(0, colon);
|
||||||
|
int port = std::stoi(hostport.substr(colon + 1));
|
||||||
|
|
||||||
|
if (rank == 0) {
|
||||||
|
// create a unique id on the rank 0
|
||||||
|
CHECK_NCCL(ncclGetUniqueId(&id));
|
||||||
|
|
||||||
|
// create a socket to send the unique id to all other ranks
|
||||||
|
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
|
||||||
|
if (sock < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] Couldn't create socket (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
sockaddr_in serv = {};
|
||||||
|
serv.sin_family = AF_INET;
|
||||||
|
serv.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||||
|
serv.sin_port = htons(port);
|
||||||
|
|
||||||
|
int reuse = 1;
|
||||||
|
// Without this, if rank-0 crashes or restarts process quickly,
|
||||||
|
// the OS might refuse to let binding to the same port, so reuse
|
||||||
|
|
||||||
|
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] setsockopt() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] bind() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
if (listen(sock, size - 1) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] listen() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int peer = 1; peer < size; ++peer) {
|
||||||
|
int conn = accept(sock, nullptr, nullptr);
|
||||||
|
if (conn < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] accept() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
sendAll(conn, &id, sizeof(id));
|
||||||
|
close(conn);
|
||||||
|
}
|
||||||
|
close(sock);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Here just wanted to make show that rank 0 has enough time to bind
|
||||||
|
// so we will retry to connect until max attempts
|
||||||
|
|
||||||
|
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
if (sock < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] socket() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
hostent* he = gethostbyname(host.c_str());
|
||||||
|
if (!he) {
|
||||||
|
throw std::runtime_error("[nccl] lookup failed for host: " + host);
|
||||||
|
}
|
||||||
|
sockaddr_in serv = {};
|
||||||
|
serv.sin_family = AF_INET;
|
||||||
|
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
|
||||||
|
serv.sin_port = htons(port);
|
||||||
|
|
||||||
|
const int max_retries = 30;
|
||||||
|
int attempt = 0;
|
||||||
|
bool connected = false;
|
||||||
|
|
||||||
|
for (attempt = 0; attempt < max_retries; ++attempt) {
|
||||||
|
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
||||||
|
0) {
|
||||||
|
connected = true;
|
||||||
|
std::cout << "[Rank " << rank << "] Connected successfully on attempt "
|
||||||
|
<< attempt + 1 << std::endl;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (errno != ECONNREFUSED) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!connected) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[Rank " << rank << "] connect() failed after " << attempt
|
||||||
|
<< " retries: " << strerror(errno);
|
||||||
|
close(sock);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
recvAll(sock, &id, sizeof(id));
|
||||||
|
close(sock);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
class NCCLGroup : public GroupImpl {
|
||||||
|
public:
|
||||||
|
NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
|
||||||
|
: rank_(worldRank),
|
||||||
|
size_(worldSize),
|
||||||
|
comm_(nullptr),
|
||||||
|
initMethod_(initMethod) {
|
||||||
|
if (initialized_)
|
||||||
|
return;
|
||||||
|
int ndev;
|
||||||
|
CHECK_CUDA(cudaGetDeviceCount(&ndev));
|
||||||
|
CHECK_CUDA(cudaSetDevice(rank_ % ndev));
|
||||||
|
detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
|
||||||
|
CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
|
||||||
|
initialized_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
~NCCLGroup() {
|
||||||
|
ncclCommDestroy(comm_);
|
||||||
|
ncclGroupEnd();
|
||||||
|
initialized_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
Stream communication_stream(StreamOrDevice s) override {
|
||||||
|
return to_stream(s, Device::gpu);
|
||||||
|
}
|
||||||
|
|
||||||
|
int rank() override {
|
||||||
|
return rank_;
|
||||||
|
}
|
||||||
|
|
||||||
|
int size() override {
|
||||||
|
return size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_sum(const array& input, array& output, Stream stream) override {
|
||||||
|
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||||
|
using T = typename decltype(type_tag)::type;
|
||||||
|
all_reduce_impl<T>(input, output, stream, dt, ncclSum);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||||
|
throw std::runtime_error("[nccl] Group split not supported.");
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_gather(const array& input, array& output, Stream stream) override {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[nccl] All gather not supported in NCCL backend.");
|
||||||
|
}
|
||||||
|
|
||||||
|
void send(const array& input, int dst, Stream stream) override {
|
||||||
|
throw std::runtime_error("[nccl] Send not supported in NCCL backend.");
|
||||||
|
}
|
||||||
|
|
||||||
|
void recv(array& output, int src, Stream stream) override {
|
||||||
|
throw std::runtime_error("[nccl] Recv not supported in NCCL backend.");
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_max(const array& input, array& output, Stream stream) override {
|
||||||
|
throw std::runtime_error("[nccl] All max not supported in NCCL backend.");
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_min(const array& input, array& output, Stream stream) override {
|
||||||
|
throw std::runtime_error("[nccl] All min not supported in NCCL backend.");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void all_reduce_impl(
|
||||||
|
const array& input,
|
||||||
|
array& output,
|
||||||
|
Stream stream,
|
||||||
|
ncclDataType_t dt,
|
||||||
|
ncclRedOp_t op) {
|
||||||
|
auto& encoder = cu::get_command_encoder(stream);
|
||||||
|
|
||||||
|
CHECK_NCCL(ncclAllReduce(
|
||||||
|
input.data<T>(),
|
||||||
|
output.data<T>(),
|
||||||
|
input.size(),
|
||||||
|
dt,
|
||||||
|
op,
|
||||||
|
comm_,
|
||||||
|
encoder.stream()));
|
||||||
|
}
|
||||||
|
|
||||||
|
int rank_, size_;
|
||||||
|
std::string initMethod_;
|
||||||
|
ncclUniqueId uniqueId_;
|
||||||
|
ncclComm_t comm_;
|
||||||
|
bool initialized_ = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
static std::string get_env_var_or_throw(const char* env_var_name) {
|
||||||
|
const char* value = std::getenv(env_var_name);
|
||||||
|
if (value == nullptr) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] Required environment variable '" << env_var_name
|
||||||
|
<< "' is not set. "
|
||||||
|
<< "Please set it before initializing the distributed backend.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
return std::string(value);
|
||||||
|
}
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||||
|
std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP");
|
||||||
|
std::string port = detail::get_env_var_or_throw("NCCL_PORT");
|
||||||
|
std::string rank_str = detail::get_env_var_or_throw("MLX_RANK");
|
||||||
|
std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE");
|
||||||
|
|
||||||
|
int rank = std::stoi(rank_str);
|
||||||
|
int n_nodes = std::stoi(n_nodes_str);
|
||||||
|
std::string init_method = "tcp://" + host + ":" + port;
|
||||||
|
|
||||||
|
return std::make_shared<NCCLGroup>(rank, n_nodes, init_method);
|
||||||
|
}
|
||||||
|
} // namespace mlx::core::distributed::nccl
|
12
mlx/distributed/nccl/nccl.h
Normal file
12
mlx/distributed/nccl/nccl.h
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/distributed/distributed.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
|
||||||
|
bool is_available();
|
||||||
|
std::shared_ptr<GroupImpl> init(bool strict = false);
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::nccl
|
20
mlx/distributed/nccl/no_nccl.cpp
Normal file
20
mlx/distributed/nccl/no_nccl.cpp
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/distributed/nccl/nccl.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||||
|
if (strict) {
|
||||||
|
throw std::runtime_error("Cannot initialize nccl distributed backend.");
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::nccl
|
@ -2,6 +2,9 @@
|
|||||||
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/cuda.h"
|
||||||
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
#include "mlx/distributed/ops.h"
|
#include "mlx/distributed/ops.h"
|
||||||
#include "mlx/distributed/primitives.h"
|
#include "mlx/distributed/primitives.h"
|
||||||
|
|
||||||
@ -28,11 +31,12 @@ array all_sum(
|
|||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
auto stream = detail::communication_stream(group, s);
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllReduce>(
|
std::make_shared<AllReduce>(stream, group, AllReduce::Sum),
|
||||||
to_stream(s, Device::cpu), group, AllReduce::Sum),
|
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,11 +49,12 @@ array all_max(
|
|||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
auto stream = detail::communication_stream(group, s);
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllReduce>(
|
std::make_shared<AllReduce>(stream, group, AllReduce::Max),
|
||||||
to_stream(s, Device::cpu), group, AllReduce::Max),
|
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,11 +67,12 @@ array all_min(
|
|||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
auto stream = detail::communication_stream(group, s);
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllReduce>(
|
std::make_shared<AllReduce>(stream, group, AllReduce::Min),
|
||||||
to_stream(s, Device::cpu), group, AllReduce::Min),
|
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,6 +85,7 @@ array all_gather(
|
|||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
auto stream = detail::communication_stream(group, s);
|
||||||
|
|
||||||
auto result_shape = x.shape();
|
auto result_shape = x.shape();
|
||||||
if (result_shape.size() == 0) {
|
if (result_shape.size() == 0) {
|
||||||
@ -89,7 +96,7 @@ array all_gather(
|
|||||||
return array(
|
return array(
|
||||||
std::move(result_shape),
|
std::move(result_shape),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllGather>(to_stream(s, Device::cpu), group),
|
std::make_shared<AllGather>(stream, group),
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -103,6 +110,7 @@ array send(
|
|||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
throw std::invalid_argument("Cannot send to a singleton group");
|
throw std::invalid_argument("Cannot send to a singleton group");
|
||||||
}
|
}
|
||||||
|
auto stream = detail::communication_stream(group, s);
|
||||||
|
|
||||||
if (dst < 0 || dst >= group.size()) {
|
if (dst < 0 || dst >= group.size()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -112,10 +120,7 @@ array send(
|
|||||||
}
|
}
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(), x.dtype(), std::make_shared<Send>(stream, group, dst), {x});
|
||||||
x.dtype(),
|
|
||||||
std::make_shared<Send>(to_stream(s, Device::cpu), group, dst),
|
|
||||||
{x});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array recv(
|
array recv(
|
||||||
@ -129,6 +134,7 @@ array recv(
|
|||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
throw std::invalid_argument("Cannot recv from a singleton group");
|
throw std::invalid_argument("Cannot recv from a singleton group");
|
||||||
}
|
}
|
||||||
|
auto stream = detail::communication_stream(group, s);
|
||||||
|
|
||||||
if (src < 0 || src >= group.size()) {
|
if (src < 0 || src >= group.size()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -139,7 +145,7 @@ array recv(
|
|||||||
return array(
|
return array(
|
||||||
std::move(shape),
|
std::move(shape),
|
||||||
std::move(dtype),
|
std::move(dtype),
|
||||||
std::make_shared<Recv>(to_stream(s, Device::cpu), group, src),
|
std::make_shared<Recv>(stream, group, src),
|
||||||
std::vector<array>{});
|
std::vector<array>{});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -619,6 +619,10 @@ class RingGroup : public GroupImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Stream communication_stream(StreamOrDevice s) override {
|
||||||
|
return to_stream(s, Device::cpu);
|
||||||
|
}
|
||||||
|
|
||||||
int rank() override {
|
int rank() override {
|
||||||
return rank_;
|
return rank_;
|
||||||
}
|
}
|
||||||
|
@ -415,6 +415,48 @@ def launch_mpi(parser, hosts, args, command):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def launch_nccl(parser, hosts, args, command):
|
||||||
|
master_host = hosts[0].ips[0]
|
||||||
|
|
||||||
|
if master_host != "127.0.0.1":
|
||||||
|
raise ValueError("The NCCL backend only supports localhost for now. ")
|
||||||
|
master_port = args.nccl_port
|
||||||
|
world_size = len(hosts)
|
||||||
|
|
||||||
|
base_env = os.environ.copy()
|
||||||
|
base_env.update(
|
||||||
|
{
|
||||||
|
"NCCL_DEBUG": "INFO",
|
||||||
|
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
|
||||||
|
"NCCL_HOST_IP": master_host,
|
||||||
|
"NCCL_PORT": str(master_port),
|
||||||
|
"MLX_WORLD_SIZE": str(world_size),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
procs = []
|
||||||
|
try:
|
||||||
|
for rank in range(world_size):
|
||||||
|
env = base_env.copy()
|
||||||
|
env["MLX_RANK"] = str(rank)
|
||||||
|
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
|
||||||
|
p = Popen(command, env=env)
|
||||||
|
procs.append(p)
|
||||||
|
|
||||||
|
for p in procs:
|
||||||
|
ret = p.wait()
|
||||||
|
if ret != 0:
|
||||||
|
raise RuntimeError(f"Rank process exited with {ret}")
|
||||||
|
|
||||||
|
except (RuntimeError, KeyboardInterrupt) as err:
|
||||||
|
for p in procs:
|
||||||
|
if p.poll() is None:
|
||||||
|
try:
|
||||||
|
p.kill()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def check_ssh_connections(hosts):
|
def check_ssh_connections(hosts):
|
||||||
results = [False] * len(hosts)
|
results = [False] * len(hosts)
|
||||||
|
|
||||||
@ -665,7 +707,7 @@ def distributed_config():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
choices=["ring", "mpi"],
|
choices=["ring", "mpi", "nccl"],
|
||||||
default="ring",
|
default="ring",
|
||||||
help="Which distributed backend to configure",
|
help="Which distributed backend to configure",
|
||||||
)
|
)
|
||||||
@ -737,7 +779,7 @@ def main():
|
|||||||
parser.add_argument("--hostfile", help="The file containing the hosts")
|
parser.add_argument("--hostfile", help="The file containing the hosts")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
choices=["ring", "mpi"],
|
choices=["ring", "mpi", "nccl"],
|
||||||
default="ring",
|
default="ring",
|
||||||
help="Which distributed backend to launch",
|
help="Which distributed backend to launch",
|
||||||
)
|
)
|
||||||
@ -769,6 +811,13 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cwd", help="Set the working directory on each node to the provided one"
|
"--cwd", help="Set the working directory on each node to the provided one"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--nccl-port",
|
||||||
|
type=int,
|
||||||
|
default=12345,
|
||||||
|
help="The port to use for the NCCL communication (only for nccl backend)",
|
||||||
|
)
|
||||||
|
|
||||||
args, rest = parser.parse_known_args()
|
args, rest = parser.parse_known_args()
|
||||||
if rest[0] == "--":
|
if rest[0] == "--":
|
||||||
rest.pop(0)
|
rest.pop(0)
|
||||||
@ -799,8 +848,10 @@ def main():
|
|||||||
# Launch
|
# Launch
|
||||||
if args.backend == "ring":
|
if args.backend == "ring":
|
||||||
launch_ring(parser, hosts, args, rest)
|
launch_ring(parser, hosts, args, rest)
|
||||||
elif args.backend == "mpi":
|
if args.backend == "mpi":
|
||||||
launch_mpi(parser, hosts, args, rest)
|
launch_mpi(parser, hosts, args, rest)
|
||||||
|
if args.backend == "nccl":
|
||||||
|
launch_nccl(parser, hosts, args, rest)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -76,6 +76,7 @@ def average_gradients(
|
|||||||
group: Optional[mx.distributed.Group] = None,
|
group: Optional[mx.distributed.Group] = None,
|
||||||
all_reduce_size: int = 32 * 1024**2,
|
all_reduce_size: int = 32 * 1024**2,
|
||||||
communication_type: Optional[mx.Dtype] = None,
|
communication_type: Optional[mx.Dtype] = None,
|
||||||
|
stream: mx.Stream = mx.cpu,
|
||||||
):
|
):
|
||||||
"""Average the gradients across the distributed processes in the passed group.
|
"""Average the gradients across the distributed processes in the passed group.
|
||||||
|
|
||||||
@ -94,6 +95,7 @@ def average_gradients(
|
|||||||
communication_type (Optional[mlx.core.Dtype]): If provided cast to this
|
communication_type (Optional[mlx.core.Dtype]): If provided cast to this
|
||||||
type before performing the communication. Typically cast to a
|
type before performing the communication. Typically cast to a
|
||||||
smaller float to reduce the communication size. Default: ``None``.
|
smaller float to reduce the communication size. Default: ``None``.
|
||||||
|
stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
|
||||||
"""
|
"""
|
||||||
group = group or mx.distributed.init()
|
group = group or mx.distributed.init()
|
||||||
N = group.size()
|
N = group.size()
|
||||||
@ -104,7 +106,7 @@ def average_gradients(
|
|||||||
def _average(x):
|
def _average(x):
|
||||||
dt = x.dtype
|
dt = x.dtype
|
||||||
x = x.astype(communication_type) if communication_type is not None else x
|
x = x.astype(communication_type) if communication_type is not None else x
|
||||||
return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
|
return mx.distributed.all_sum(x, stream=stream).astype(dt) / N
|
||||||
|
|
||||||
if all_reduce_size <= 0:
|
if all_reduce_size <= 0:
|
||||||
return tree_map(_average, gradients)
|
return tree_map(_average, gradients)
|
||||||
|
@ -79,7 +79,7 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
in case ``mx.distributed.is_available()`` returns False otherwise
|
in case ``mx.distributed.is_available()`` returns False otherwise
|
||||||
it throws a runtime error. Default: ``False``
|
it throws a runtime error. Default: ``False``
|
||||||
backend (str, optional): Which distributed backend to initialize.
|
backend (str, optional): Which distributed backend to initialize.
|
||||||
Possible values ``mpi``, ``ring``, ``any``. If set to ``any`` all
|
Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
|
||||||
available backends are tried and the first one that succeeds
|
available backends are tried and the first one that succeeds
|
||||||
becomes the global group which will be returned in subsequent
|
becomes the global group which will be returned in subsequent
|
||||||
calls. Default: ``any``
|
calls. Default: ``any``
|
||||||
|
284
python/tests/nccl_test_distributed.py
Normal file
284
python/tests/nccl_test_distributed.py
Normal file
@ -0,0 +1,284 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import mlx_tests
|
||||||
|
from mlx.nn.layers.distributed import shard_inplace, shard_linear
|
||||||
|
from mlx.nn.utils import average_gradients
|
||||||
|
|
||||||
|
|
||||||
|
class TestNCCLDistributed(mlx_tests.MLXTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
world = mx.distributed.init(strict=True, backend="nccl")
|
||||||
|
rank = world.rank()
|
||||||
|
mx.set_default_device(mx.Device(mx.gpu, rank % 8))
|
||||||
|
|
||||||
|
def test_all_reduce(self):
|
||||||
|
world = mx.distributed.init()
|
||||||
|
dtypes = [
|
||||||
|
(mx.int8, 0),
|
||||||
|
(mx.uint8, 0),
|
||||||
|
(mx.int32, 0),
|
||||||
|
(mx.uint32, 0),
|
||||||
|
(mx.float32, 1e-6),
|
||||||
|
(mx.float16, 5e-3),
|
||||||
|
(mx.bfloat16, 1e-1),
|
||||||
|
]
|
||||||
|
sizes = [
|
||||||
|
(7,),
|
||||||
|
(10,),
|
||||||
|
(1024,),
|
||||||
|
(1024, 1024),
|
||||||
|
]
|
||||||
|
key = mx.random.key(0)
|
||||||
|
|
||||||
|
for dt, rtol in dtypes:
|
||||||
|
for sh in sizes:
|
||||||
|
x = (
|
||||||
|
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
|
||||||
|
).astype(dt)
|
||||||
|
|
||||||
|
# All sum
|
||||||
|
y = mx.distributed.all_sum(x[world.rank()])
|
||||||
|
z = x.sum(0)
|
||||||
|
maxrelerror = (y - z).abs()
|
||||||
|
if rtol > 0:
|
||||||
|
maxrelerror /= z.abs()
|
||||||
|
maxrelerror = maxrelerror.max()
|
||||||
|
self.assertLessEqual(maxrelerror, rtol)
|
||||||
|
|
||||||
|
def test_average_gradients(self):
|
||||||
|
original_all_sum = mx.distributed.all_sum
|
||||||
|
n_calls = 0
|
||||||
|
xtype = None
|
||||||
|
|
||||||
|
def new_all_sum(x, **kwargs):
|
||||||
|
nonlocal n_calls
|
||||||
|
nonlocal xtype
|
||||||
|
|
||||||
|
n_calls += 1
|
||||||
|
if xtype is not None:
|
||||||
|
self.assertEqual(xtype, x.dtype)
|
||||||
|
|
||||||
|
return original_all_sum(x, **kwargs)
|
||||||
|
|
||||||
|
mx.distributed.all_sum = new_all_sum
|
||||||
|
try:
|
||||||
|
grads = [mx.ones(10) for i in range(10)]
|
||||||
|
new_grads = average_gradients(grads, stream=mx.gpu)
|
||||||
|
mx.eval(new_grads)
|
||||||
|
self.assertEqual(len(new_grads), 10)
|
||||||
|
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
||||||
|
self.assertEqual(n_calls, 1)
|
||||||
|
|
||||||
|
n_calls = 0
|
||||||
|
new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu)
|
||||||
|
mx.eval(new_grads)
|
||||||
|
self.assertEqual(len(new_grads), 10)
|
||||||
|
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
||||||
|
self.assertEqual(n_calls, 2)
|
||||||
|
|
||||||
|
n_calls = 0
|
||||||
|
new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu)
|
||||||
|
mx.eval(new_grads)
|
||||||
|
self.assertEqual(len(new_grads), 10)
|
||||||
|
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
||||||
|
self.assertEqual(n_calls, 10)
|
||||||
|
|
||||||
|
n_calls = 0
|
||||||
|
xtype = mx.float16
|
||||||
|
new_grads = average_gradients(
|
||||||
|
grads,
|
||||||
|
all_reduce_size=2 * 50,
|
||||||
|
communication_type=mx.float16,
|
||||||
|
stream=mx.gpu,
|
||||||
|
)
|
||||||
|
mx.eval(new_grads)
|
||||||
|
self.assertEqual(len(new_grads), 10)
|
||||||
|
self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
|
||||||
|
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
||||||
|
self.assertEqual(n_calls, 2)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
mx.distributed.all_sum = original_all_sum
|
||||||
|
|
||||||
|
def test_donation(self):
|
||||||
|
x = mx.random.normal((1024,))
|
||||||
|
mx.eval(x)
|
||||||
|
mx.synchronize()
|
||||||
|
|
||||||
|
mx.reset_peak_memory()
|
||||||
|
scale = mx.array(2.0)
|
||||||
|
y = mx.distributed.all_sum(x)
|
||||||
|
mx.eval(y)
|
||||||
|
mx.synchronize()
|
||||||
|
all_sum_only = mx.get_peak_memory()
|
||||||
|
y = mx.distributed.all_sum(x) * scale
|
||||||
|
mx.eval(y)
|
||||||
|
mx.synchronize()
|
||||||
|
all_sum_with_binary = mx.get_peak_memory()
|
||||||
|
|
||||||
|
self.assertEqual(all_sum_only, all_sum_with_binary)
|
||||||
|
|
||||||
|
def test_shard_linear(self):
|
||||||
|
# Seed the prng to have the same inputs and weights generated everywhere
|
||||||
|
mx.random.seed(0xF0F0F0F0)
|
||||||
|
|
||||||
|
# Prepare inputs
|
||||||
|
world = mx.distributed.init()
|
||||||
|
part = (
|
||||||
|
slice(None),
|
||||||
|
slice(
|
||||||
|
world.rank() * 1024 // world.size(),
|
||||||
|
(world.rank() + 1) * 1024 // world.size(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
x = mx.random.normal((4, 1024))
|
||||||
|
|
||||||
|
# Create and shard some linear layers
|
||||||
|
lin = nn.Linear(1024, 1024, bias=True)
|
||||||
|
slin1 = shard_linear(lin, "all-to-sharded")
|
||||||
|
slin2 = shard_linear(lin, "sharded-to-all")
|
||||||
|
y = lin(x)
|
||||||
|
y1 = slin1(x)
|
||||||
|
y2 = slin2(x[part])
|
||||||
|
self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
|
||||||
|
self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
|
# Check the backward works as expected
|
||||||
|
def dummy_loss(model, x, y):
|
||||||
|
return (model(x) * y).sum()
|
||||||
|
|
||||||
|
mod = nn.Sequential(
|
||||||
|
nn.Linear(128, 128),
|
||||||
|
nn.Linear(128, 128),
|
||||||
|
nn.Linear(128, 128),
|
||||||
|
nn.Linear(128, 128),
|
||||||
|
)
|
||||||
|
smod = nn.Sequential(
|
||||||
|
shard_linear(mod.layers[0], "all-to-sharded"),
|
||||||
|
shard_linear(mod.layers[1], "sharded-to-all"),
|
||||||
|
shard_linear(mod.layers[2], "all-to-sharded"),
|
||||||
|
shard_linear(mod.layers[3], "sharded-to-all"),
|
||||||
|
)
|
||||||
|
|
||||||
|
grad1 = nn.value_and_grad(mod, dummy_loss)
|
||||||
|
grad2 = nn.value_and_grad(smod, dummy_loss)
|
||||||
|
|
||||||
|
x = mx.random.normal((4, 128))
|
||||||
|
y = mx.random.normal((4, 128))
|
||||||
|
|
||||||
|
l1, g1 = grad1(mod, x, y)
|
||||||
|
l2, g2 = grad2(smod, x, y)
|
||||||
|
mx.eval(l1, g1, l2, g2)
|
||||||
|
|
||||||
|
part = slice(
|
||||||
|
world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(mx.allclose(l1, l2))
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
g1["layers"][0]["weight"][part],
|
||||||
|
g2["layers"][0]["weight"],
|
||||||
|
atol=1e-6,
|
||||||
|
rtol=1e-4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
g1["layers"][2]["weight"][part],
|
||||||
|
g2["layers"][2]["weight"],
|
||||||
|
atol=1e-6,
|
||||||
|
rtol=1e-4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
g1["layers"][1]["weight"][:, part],
|
||||||
|
g2["layers"][1]["weight"],
|
||||||
|
atol=1e-6,
|
||||||
|
rtol=1e-4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
g1["layers"][3]["weight"][:, part],
|
||||||
|
g2["layers"][3]["weight"],
|
||||||
|
atol=1e-6,
|
||||||
|
rtol=1e-4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
g1["layers"][0]["bias"][part],
|
||||||
|
g2["layers"][0]["bias"],
|
||||||
|
atol=1e-6,
|
||||||
|
rtol=1e-4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
g1["layers"][2]["bias"][part],
|
||||||
|
g2["layers"][2]["bias"],
|
||||||
|
atol=1e-6,
|
||||||
|
rtol=1e-4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_shard_predicate(self):
|
||||||
|
mx.random.seed(0xF0F0F0F0)
|
||||||
|
|
||||||
|
class MyConv(nn.Module):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.aggregate = kwargs.pop("aggregate", False)
|
||||||
|
self.conv = nn.Conv2d(*args, **kwargs)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
if self.aggregate:
|
||||||
|
x = mx.distributed.all_sum(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def sharding(path, weight):
|
||||||
|
parts = path.split(".")
|
||||||
|
even = int(parts[1]) % 2 == 0
|
||||||
|
if even:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return -1 if parts[-1] != "bias" else None
|
||||||
|
|
||||||
|
mod = nn.Sequential(
|
||||||
|
MyConv(3, 128, kernel_size=3),
|
||||||
|
MyConv(128, 128, kernel_size=3),
|
||||||
|
MyConv(128, 128, kernel_size=3),
|
||||||
|
MyConv(128, 3, kernel_size=3),
|
||||||
|
)
|
||||||
|
smod = nn.Sequential(
|
||||||
|
MyConv(3, 128, kernel_size=3),
|
||||||
|
MyConv(128, 128, kernel_size=3, aggregate=True),
|
||||||
|
MyConv(128, 128, kernel_size=3),
|
||||||
|
MyConv(128, 3, kernel_size=3, aggregate=True),
|
||||||
|
)
|
||||||
|
smod.update(mod.parameters())
|
||||||
|
shard_inplace(smod, sharding)
|
||||||
|
|
||||||
|
x = mx.random.normal((4, 16, 16, 3))
|
||||||
|
y1 = mod(x)
|
||||||
|
y2 = smod(x)
|
||||||
|
self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mlx_tests.MLXTestRunner()
|
Loading…
Reference in New Issue
Block a user