mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Merge bc6f00c00e into 828c5f1137
This commit is contained in:
54
cmake/FindNCCL.cmake
Normal file
54
cmake/FindNCCL.cmake
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# FindNCCL.cmake
|
||||||
|
# This module finds the NVIDIA NCCL library and its include directories.
|
||||||
|
|
||||||
|
set(NCCL_ROOT_DIR
|
||||||
|
$ENV{NCCL_ROOT_DIR}
|
||||||
|
CACHE PATH "Folder contains NVIDIA NCCL")
|
||||||
|
|
||||||
|
find_path(
|
||||||
|
NCCL_INCLUDE_DIRS
|
||||||
|
NAMES nccl.h
|
||||||
|
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/include)
|
||||||
|
|
||||||
|
if($ENV{USE_STATIC_NCCL})
|
||||||
|
message(
|
||||||
|
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
|
||||||
|
set(NCCL_LIBNAME "libnccl_static.a")
|
||||||
|
else()
|
||||||
|
set(NCCL_LIBNAME "nccl")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
find_library(
|
||||||
|
NCCL_LIBRARIES
|
||||||
|
NAMES ${NCCL_LIBNAME}
|
||||||
|
HINTS ${NCCL_LIB_DIR}
|
||||||
|
${NCCL_ROOT_DIR}
|
||||||
|
${NCCL_ROOT_DIR}/lib
|
||||||
|
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
|
||||||
|
${NCCL_ROOT_DIR}/lib64
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/lib
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
||||||
|
|
||||||
|
include(FindPackageHandleStandardArgs)
|
||||||
|
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
|
||||||
|
NCCL_LIBRARIES)
|
||||||
|
|
||||||
|
if(NCCL_FOUND)
|
||||||
|
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
|
||||||
|
message(
|
||||||
|
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
|
||||||
|
file(
|
||||||
|
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
|
||||||
|
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
|
||||||
|
LIMIT_COUNT 1)
|
||||||
|
if(NCCL_MAJOR_VERSION_DEFINED)
|
||||||
|
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
|
||||||
|
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
|
||||||
|
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
|
||||||
|
endif()
|
||||||
|
message(
|
||||||
|
STATUS
|
||||||
|
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
|
||||||
|
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
|
||||||
|
endif()
|
||||||
@@ -19,6 +19,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||||
|
|||||||
87
mlx/backend/cuda/distributed.cu
Normal file
87
mlx/backend/cuda/distributed.cu
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/distributed/primitives.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
namespace distributed {
|
||||||
|
void AllReduce::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
// Here I assume for now that in is donatable and contiguous.
|
||||||
|
// TODO
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(outputs.size() == 1);
|
||||||
|
|
||||||
|
auto& input = inputs[0];
|
||||||
|
auto& output = outputs[0];
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(stream());
|
||||||
|
output.set_data(allocator::malloc(output.nbytes()));
|
||||||
|
|
||||||
|
encoder.set_input_array(input);
|
||||||
|
encoder.set_output_array(output);
|
||||||
|
|
||||||
|
auto capture = encoder.capture_context();
|
||||||
|
auto& s = stream();
|
||||||
|
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case Sum:
|
||||||
|
distributed::detail::all_sum(group(), input, output, s);
|
||||||
|
break;
|
||||||
|
case Max:
|
||||||
|
distributed::detail::all_max(group(), input, output, s);
|
||||||
|
break;
|
||||||
|
case Min:
|
||||||
|
distributed::detail::all_min(group(), input, output, s);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Only all reduce sum, max, and min are supported.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Send::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
// Here FOR NOW I assume that it is always row_contigious
|
||||||
|
// because not sure how to copy correctly
|
||||||
|
// TODO
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(outputs.size() == 1);
|
||||||
|
|
||||||
|
distributed::detail::send(group(), inputs[0], dst_, stream());
|
||||||
|
outputs[0].copy_shared_buffer(inputs[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Recv::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
assert(inputs.size() == 0);
|
||||||
|
assert(outputs.size() == 1);
|
||||||
|
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
|
||||||
|
distributed::detail::recv(group(), outputs[0], src_, stream());
|
||||||
|
}
|
||||||
|
|
||||||
|
void AllGather::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
// Here FOR NOW I assume that it is always row_contigious
|
||||||
|
// because not sure how to copy correctly
|
||||||
|
// TODO
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(outputs.size() == 1);
|
||||||
|
|
||||||
|
auto& input = inputs[0];
|
||||||
|
auto& output = outputs[0];
|
||||||
|
|
||||||
|
output.copy_shared_buffer(input);
|
||||||
|
distributed::detail::all_gather(group(), input, output, stream());
|
||||||
|
}
|
||||||
|
}// namespace distributed
|
||||||
|
}
|
||||||
60
mlx/backend/cuda/iterators/strided_iterator.cuh
Normal file
60
mlx/backend/cuda/iterators/strided_iterator.cuh
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <thrust/iterator/iterator_adaptor.h>
|
||||||
|
#include <thrust/iterator/iterator_facade.h>
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
// RandomAccessIterator for strided access to array entries.
|
||||||
|
template <typename Iterator, typename Stride = int64_t>
|
||||||
|
class strided_iterator
|
||||||
|
: public thrust::
|
||||||
|
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
|
||||||
|
public:
|
||||||
|
using super_t =
|
||||||
|
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
|
||||||
|
|
||||||
|
using reference = typename super_t::reference;
|
||||||
|
using difference_type = typename super_t::difference_type;
|
||||||
|
|
||||||
|
__host__ __device__ strided_iterator(Iterator it, Stride stride)
|
||||||
|
: super_t(it), stride_(stride) {}
|
||||||
|
|
||||||
|
__host__ __device__ Stride stride() const {
|
||||||
|
return stride_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend class thrust::iterator_core_access;
|
||||||
|
|
||||||
|
__host__ __device__ bool equal(const strided_iterator& other) const {
|
||||||
|
return this->base() == other.base();
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ __device__ void advance(difference_type n) {
|
||||||
|
this->base_reference() += n * stride_;
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ __device__ void increment() {
|
||||||
|
this->base_reference() += stride_;
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ __device__ void decrement() {
|
||||||
|
this->base_reference() -= stride_;
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ __device__ difference_type
|
||||||
|
distance_to(const strided_iterator& other) const {
|
||||||
|
const difference_type dist = other.base() - this->base();
|
||||||
|
_CCCL_ASSERT(
|
||||||
|
dist % stride() == 0,
|
||||||
|
"Underlying iterator difference must be divisible by the stride");
|
||||||
|
return dist / stride();
|
||||||
|
}
|
||||||
|
|
||||||
|
Stride stride_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
||||||
@@ -6,6 +6,7 @@
|
|||||||
#include "mlx/backend/cuda/gemms/gemv.h"
|
#include "mlx/backend/cuda/gemms/gemv.h"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
#include <nvtx3/nvtx3.hpp>
|
#include <nvtx3/nvtx3.hpp>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|||||||
@@ -57,11 +57,11 @@ NO_GPU(ScaledDotProductAttention)
|
|||||||
NO_GPU_MULTI(CustomKernel)
|
NO_GPU_MULTI(CustomKernel)
|
||||||
} // namespace fast
|
} // namespace fast
|
||||||
|
|
||||||
namespace distributed {
|
// namespace distributed {
|
||||||
NO_GPU_MULTI(AllReduce)
|
// NO_GPU_MULTI(AllReduce)
|
||||||
NO_GPU_MULTI(AllGather)
|
// NO_GPU_MULTI(AllGather)
|
||||||
NO_GPU_MULTI(Send)
|
// NO_GPU_MULTI(Send)
|
||||||
NO_GPU_MULTI(Recv)
|
// NO_GPU_MULTI(Recv)
|
||||||
} // namespace distributed
|
// } // namespace distributed
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||||
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
|
|
||||||
|
#include <thrust/device_ptr.h>
|
||||||
|
#include <cub/device/device_reduce.cuh>
|
||||||
|
#include <cub/device/device_segmented_reduce.cuh>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||||
|
// Allocate temporary storage.
|
||||||
|
size_t size;
|
||||||
|
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
|
||||||
|
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||||
|
encoder.add_temporary(temp);
|
||||||
|
// Run op.
|
||||||
|
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||||
|
// Allocate temporary storage.
|
||||||
|
size_t size;
|
||||||
|
CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
|
||||||
|
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||||
|
encoder.add_temporary(temp);
|
||||||
|
// Run op.
|
||||||
|
CHECK_CUDA_ERROR(
|
||||||
|
cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MultiplyOp {
|
||||||
|
int factor;
|
||||||
|
__device__ int operator()(int i) {
|
||||||
|
return i * factor;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void segmented_reduce(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
Reduce::ReduceType reduce_type,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const ReductionPlan& plan) {
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
|
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||||
|
using InType = cuda_type_t<CTYPE>;
|
||||||
|
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||||
|
auto in_iter = cu::make_cast_iterator<OutType>(
|
||||||
|
thrust::device_pointer_cast(in.data<InType>()));
|
||||||
|
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
|
||||||
|
auto init = cu::ReduceInit<OP, InType>::value();
|
||||||
|
|
||||||
|
if (plan.type == ContiguousAllReduce) {
|
||||||
|
cub_all_reduce(
|
||||||
|
encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
|
||||||
|
} else if (plan.type == ContiguousReduce) {
|
||||||
|
auto offsets = thrust::make_transform_iterator(
|
||||||
|
thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
|
||||||
|
cub_segmented_reduce(
|
||||||
|
encoder,
|
||||||
|
in_iter,
|
||||||
|
out_ptr,
|
||||||
|
out.size(),
|
||||||
|
offsets,
|
||||||
|
offsets + 1,
|
||||||
|
OP(),
|
||||||
|
init,
|
||||||
|
stream);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Unsupported plan in segmented_reduce.");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -6,3 +6,4 @@ target_sources(
|
|||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
||||||
|
|||||||
@@ -2,9 +2,11 @@
|
|||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
#include "mlx/distributed/mpi/mpi.h"
|
#include "mlx/distributed/mpi/mpi.h"
|
||||||
|
#include "mlx/distributed/nccl/nccl.h"
|
||||||
#include "mlx/distributed/ring/ring.h"
|
#include "mlx/distributed/ring/ring.h"
|
||||||
|
|
||||||
namespace mlx::core::distributed {
|
namespace mlx::core::distributed {
|
||||||
@@ -80,7 +82,7 @@ class EmptyGroup : public GroupImpl {
|
|||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
bool is_available() {
|
bool is_available() {
|
||||||
return mpi::is_available() || ring::is_available();
|
return mpi::is_available() || ring::is_available() || nccl::is_available();
|
||||||
}
|
}
|
||||||
|
|
||||||
int Group::rank() const {
|
int Group::rank() const {
|
||||||
@@ -111,6 +113,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
|||||||
group = mpi::init(strict);
|
group = mpi::init(strict);
|
||||||
} else if (bk == "ring") {
|
} else if (bk == "ring") {
|
||||||
group = ring::init(strict);
|
group = ring::init(strict);
|
||||||
|
} else if (bk == "nccl") {
|
||||||
|
group = nccl::init(strict);
|
||||||
} else if (bk == "any") {
|
} else if (bk == "any") {
|
||||||
group = ring::init(false);
|
group = ring::init(false);
|
||||||
bk_ = "ring";
|
bk_ = "ring";
|
||||||
|
|||||||
8
mlx/distributed/nccl/CMakeLists.txt
Normal file
8
mlx/distributed/nccl/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
if(MLX_BUILD_CUDA)
|
||||||
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
|
||||||
|
find_package(NCCL REQUIRED)
|
||||||
|
target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
|
||||||
|
target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
|
||||||
|
else()
|
||||||
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
|
||||||
|
endif()
|
||||||
405
mlx/distributed/nccl/nccl.cpp
Normal file
405
mlx/distributed/nccl/nccl.cpp
Normal file
@@ -0,0 +1,405 @@
|
|||||||
|
#include <arpa/inet.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <nccl.h>
|
||||||
|
#include <netdb.h>
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <iostream>
|
||||||
|
#include <mutex>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/distributed/distributed.h"
|
||||||
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
|
#define CHECK_CUDA(cmd) \
|
||||||
|
do { \
|
||||||
|
cudaError_t e = cmd; \
|
||||||
|
if (e != cudaSuccess) { \
|
||||||
|
fprintf( \
|
||||||
|
stderr, \
|
||||||
|
"CUDA error %s:%d '%s'\n", \
|
||||||
|
__FILE__, \
|
||||||
|
__LINE__, \
|
||||||
|
cudaGetErrorString(e)); \
|
||||||
|
exit(1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define CHECK_NCCL(cmd) \
|
||||||
|
do { \
|
||||||
|
ncclResult_t r = cmd; \
|
||||||
|
if (r != ncclSuccess) { \
|
||||||
|
fprintf( \
|
||||||
|
stderr, \
|
||||||
|
"NCCL error %s:%d '%s'\n", \
|
||||||
|
__FILE__, \
|
||||||
|
__LINE__, \
|
||||||
|
ncclGetErrorString(r)); \
|
||||||
|
exit(1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
inline void sendAll(int sock, const void* buf, size_t len) {
|
||||||
|
const char* ptr = reinterpret_cast<const char*>(buf);
|
||||||
|
while (len > 0) {
|
||||||
|
ssize_t sent = send(sock, ptr, len, 0);
|
||||||
|
if (sent <= 0) {
|
||||||
|
perror("send");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
ptr += sent;
|
||||||
|
len -= sent;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void recvAll(int sock, void* buf, size_t len) {
|
||||||
|
char* ptr = reinterpret_cast<char*>(buf);
|
||||||
|
while (len > 0) {
|
||||||
|
ssize_t rec = recv(sock, ptr, len, 0);
|
||||||
|
if (rec <= 0) {
|
||||||
|
perror("recv");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
ptr += rec;
|
||||||
|
len -= rec;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void bootstrap_unique_id(
|
||||||
|
ncclUniqueId& id,
|
||||||
|
int rank,
|
||||||
|
int size,
|
||||||
|
const std::string& initMethod) {
|
||||||
|
// Parse the init method to extract the host and port
|
||||||
|
if (initMethod.rfind("tcp://", 0) != 0)
|
||||||
|
throw;
|
||||||
|
auto hostport = initMethod.substr(6);
|
||||||
|
auto colon = hostport.find(':');
|
||||||
|
std::string host = hostport.substr(0, colon);
|
||||||
|
int port = std::stoi(hostport.substr(colon + 1));
|
||||||
|
|
||||||
|
if (rank == 0) {
|
||||||
|
// create a unique id on the rank 0
|
||||||
|
CHECK_NCCL(ncclGetUniqueId(&id));
|
||||||
|
|
||||||
|
// create a socket to send the unique id to all other ranks
|
||||||
|
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
|
||||||
|
if (sock < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] Couldn't create socket (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
sockaddr_in serv = {};
|
||||||
|
serv.sin_family = AF_INET;
|
||||||
|
serv.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||||
|
serv.sin_port = htons(port);
|
||||||
|
|
||||||
|
int reuse = 1;
|
||||||
|
// Without this, if rank-0 crashes or restarts process quickly,
|
||||||
|
// the OS might refuse to let binding to the same port, so reuse
|
||||||
|
|
||||||
|
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] setsockopt() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] bind() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
if (listen(sock, size - 1) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] listen() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int peer = 1; peer < size; ++peer) {
|
||||||
|
int conn = accept(sock, nullptr, nullptr);
|
||||||
|
if (conn < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] accept() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
sendAll(conn, &id, sizeof(id));
|
||||||
|
close(conn);
|
||||||
|
}
|
||||||
|
close(sock);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Here just wanted to make show that rank 0 has enough time to bind
|
||||||
|
// so we will retry to connect until max attempts
|
||||||
|
|
||||||
|
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
if (sock < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] socket() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
hostent* he = gethostbyname(host.c_str());
|
||||||
|
if (!he) {
|
||||||
|
throw std::runtime_error("[nccl] lookup failed for host: " + host);
|
||||||
|
}
|
||||||
|
sockaddr_in serv = {};
|
||||||
|
serv.sin_family = AF_INET;
|
||||||
|
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
|
||||||
|
serv.sin_port = htons(port);
|
||||||
|
|
||||||
|
const int max_retries = 30;
|
||||||
|
int attempt = 0;
|
||||||
|
bool connected = false;
|
||||||
|
|
||||||
|
for (attempt = 0; attempt < max_retries; ++attempt) {
|
||||||
|
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
||||||
|
0) {
|
||||||
|
connected = true;
|
||||||
|
std::cout << "[Rank " << rank << "] Connected successfully on attempt "
|
||||||
|
<< attempt + 1 << std::endl;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (errno != ECONNREFUSED) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!connected) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[Rank " << rank << "] connect() failed after " << attempt
|
||||||
|
<< " retries: " << strerror(errno);
|
||||||
|
close(sock);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
recvAll(sock, &id, sizeof(id));
|
||||||
|
close(sock);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct type_identity {
|
||||||
|
using type = T;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void dispatch_dtype(const array& arr, F&& f) {
|
||||||
|
switch (arr.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
throw std::invalid_argument("[nccl] Boolean arrays not supported");
|
||||||
|
case int8:
|
||||||
|
f(type_identity<int8_t>{}, ncclChar);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
f(type_identity<uint8_t>{}, ncclUint8);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
f(type_identity<int32_t>{}, ncclInt);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
f(type_identity<uint32_t>{}, ncclUint32);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
f(type_identity<int64_t>{}, ncclInt64);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
f(type_identity<uint64_t>{}, ncclUint64);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
f(type_identity<float16_t>{}, ncclHalf);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
f(type_identity<bfloat16_t>{}, ncclBfloat16);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
f(type_identity<float>{}, ncclFloat);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
f(type_identity<double>{}, ncclDouble);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
class NCCLGroup : public GroupImpl {
|
||||||
|
public:
|
||||||
|
NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
|
||||||
|
: rank_(worldRank),
|
||||||
|
size_(worldSize),
|
||||||
|
comm_(nullptr),
|
||||||
|
initMethod_(initMethod) {
|
||||||
|
if (initialized_)
|
||||||
|
return;
|
||||||
|
int ndev;
|
||||||
|
CHECK_CUDA(cudaGetDeviceCount(&ndev));
|
||||||
|
CHECK_CUDA(cudaSetDevice(rank_ % ndev));
|
||||||
|
detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
|
||||||
|
CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
|
||||||
|
initialized_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
~NCCLGroup() {
|
||||||
|
ncclCommDestroy(comm_);
|
||||||
|
ncclGroupEnd();
|
||||||
|
initialized_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int rank() override {
|
||||||
|
return rank_;
|
||||||
|
}
|
||||||
|
|
||||||
|
int size() override {
|
||||||
|
return size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_sum(const array& input, array& output, Stream stream) override {
|
||||||
|
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||||
|
using T = typename decltype(type_tag)::type;
|
||||||
|
all_reduce_impl<T>(
|
||||||
|
input,
|
||||||
|
output,
|
||||||
|
stream,
|
||||||
|
dt,
|
||||||
|
ncclSum);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||||
|
throw std::runtime_error("[nccl] Group split not supported.");
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_gather(const array& input, array& output, Stream stream) override {
|
||||||
|
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||||
|
auto& encoder = cu::get_command_encoder(stream);
|
||||||
|
|
||||||
|
using T = typename decltype(type_tag)::type;
|
||||||
|
CHECK_NCCL(ncclAllGather(
|
||||||
|
input.data<T>(),
|
||||||
|
output.data<T>(),
|
||||||
|
input.size(),
|
||||||
|
dt,
|
||||||
|
comm_,
|
||||||
|
encoder.stream()));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void send(const array& input, int dst, Stream stream) override {
|
||||||
|
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||||
|
auto& encoder = cu::get_command_encoder(stream);
|
||||||
|
|
||||||
|
using T = typename decltype(type_tag)::type;
|
||||||
|
CHECK_NCCL(ncclSend(
|
||||||
|
input.data<T>(),
|
||||||
|
input.size(),
|
||||||
|
dt,
|
||||||
|
dst,
|
||||||
|
comm_,
|
||||||
|
encoder.stream()));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void recv(array& output, int src, Stream stream) override {
|
||||||
|
detail::dispatch_dtype(output, [&](auto type_tag, ncclDataType_t dt) {
|
||||||
|
using T = typename decltype(type_tag)::type;
|
||||||
|
auto& encoder = cu::get_command_encoder(stream);
|
||||||
|
|
||||||
|
CHECK_NCCL(ncclRecv(
|
||||||
|
output.data<T>(),
|
||||||
|
output.size(),
|
||||||
|
dt,
|
||||||
|
src,
|
||||||
|
comm_,
|
||||||
|
encoder.stream()));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_max(const array& input, array& output, Stream stream) override {
|
||||||
|
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||||
|
using T = typename decltype(type_tag)::type;
|
||||||
|
all_reduce_impl<T>(input, output, stream, dt, ncclMax);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_min(const array& input, array& output, Stream stream) override {
|
||||||
|
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||||
|
using T = typename decltype(type_tag)::type;
|
||||||
|
all_reduce_impl<T>(input, output, stream, dt, ncclMin);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void all_reduce_impl(
|
||||||
|
const array& input,
|
||||||
|
array& output,
|
||||||
|
Stream stream,
|
||||||
|
ncclDataType_t dt,
|
||||||
|
ncclRedOp_t op) {
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(stream);
|
||||||
|
|
||||||
|
CHECK_NCCL(ncclAllReduce(
|
||||||
|
input.data<T>(),
|
||||||
|
output.data<T>(),
|
||||||
|
input.size(),
|
||||||
|
dt,
|
||||||
|
op,
|
||||||
|
comm_,
|
||||||
|
encoder.stream()
|
||||||
|
));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
int rank_, size_;
|
||||||
|
std::string initMethod_;
|
||||||
|
ncclUniqueId uniqueId_;
|
||||||
|
ncclComm_t comm_;
|
||||||
|
bool initialized_ = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
static std::string get_env_var_or_throw(const char* env_var_name) {
|
||||||
|
const char* value = std::getenv(env_var_name);
|
||||||
|
if (value == nullptr) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] Required environment variable '" << env_var_name
|
||||||
|
<< "' is not set. "
|
||||||
|
<< "Please set it before initializing the distributed backend.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
return std::string(value);
|
||||||
|
}
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||||
|
std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP");
|
||||||
|
std::string port = detail::get_env_var_or_throw("NCCL_PORT");
|
||||||
|
std::string rank_str = detail::get_env_var_or_throw("MLX_RANK");
|
||||||
|
std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE");
|
||||||
|
|
||||||
|
int rank = std::stoi(rank_str);
|
||||||
|
int n_nodes = std::stoi(n_nodes_str);
|
||||||
|
std::string init_method = "tcp://" + host + ":" + port;
|
||||||
|
|
||||||
|
return std::make_shared<NCCLGroup>(rank, n_nodes, init_method);
|
||||||
|
}
|
||||||
|
} // namespace mlx::core::distributed::nccl
|
||||||
12
mlx/distributed/nccl/nccl.h
Normal file
12
mlx/distributed/nccl/nccl.h
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/distributed/distributed.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
|
||||||
|
bool is_available();
|
||||||
|
std::shared_ptr<GroupImpl> init(bool strict = false);
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::nccl
|
||||||
20
mlx/distributed/nccl/no_nccl.cpp
Normal file
20
mlx/distributed/nccl/no_nccl.cpp
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/distributed/nccl/nccl.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||||
|
if (strict) {
|
||||||
|
throw std::runtime_error("Cannot initialize nccl distributed backend.");
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::nccl
|
||||||
@@ -31,8 +31,7 @@ array all_sum(
|
|||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllReduce>(
|
std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
|
||||||
to_stream(s, Device::cpu), group, AllReduce::Sum),
|
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user