mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Compare commits
7 Commits
1963714f8d
...
7ea2252476
Author | SHA1 | Date | |
---|---|---|---|
![]() |
7ea2252476 | ||
![]() |
e6ae350999 | ||
![]() |
b8022c578a | ||
![]() |
70f2baf39f | ||
![]() |
71a47bc10d | ||
![]() |
e9fbdd20fb | ||
![]() |
f15a127900 |
64
cmake/FindNCCL.cmake
Normal file
64
cmake/FindNCCL.cmake
Normal file
@ -0,0 +1,64 @@
|
||||
# Find the nccl libraries
|
||||
#
|
||||
# The following variables are optionally searched for defaults NCCL_ROOT_DIR:
|
||||
# Base directory where all NCCL components are found NCCL_INCLUDE_DIR: Directory
|
||||
# where NCCL header is found NCCL_LIB_DIR: Directory where NCCL library is found
|
||||
#
|
||||
# The following are set after configuration is done: NCCL_FOUND
|
||||
# NCCL_INCLUDE_DIRS NCCL_LIBRARIES
|
||||
#
|
||||
# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks install NCCL
|
||||
# in the same location as the CUDA toolkit. See
|
||||
# https://github.com/caffe2/caffe2/issues/1601
|
||||
|
||||
set(NCCL_ROOT_DIR
|
||||
$ENV{NCCL_ROOT_DIR}
|
||||
CACHE PATH "Folder contains NVIDIA NCCL")
|
||||
|
||||
find_path(
|
||||
NCCL_INCLUDE_DIRS
|
||||
NAMES nccl.h
|
||||
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/include)
|
||||
|
||||
if($ENV{USE_STATIC_NCCL})
|
||||
message(
|
||||
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
|
||||
set(NCCL_LIBNAME "libnccl_static.a")
|
||||
else()
|
||||
set(NCCL_LIBNAME "nccl")
|
||||
endif()
|
||||
|
||||
find_library(
|
||||
NCCL_LIBRARIES
|
||||
NAMES ${NCCL_LIBNAME}
|
||||
HINTS ${NCCL_LIB_DIR}
|
||||
${NCCL_ROOT_DIR}
|
||||
${NCCL_ROOT_DIR}/lib
|
||||
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
|
||||
${NCCL_ROOT_DIR}/lib64
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
|
||||
NCCL_LIBRARIES)
|
||||
|
||||
if(NCCL_FOUND)
|
||||
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
|
||||
message(
|
||||
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
|
||||
file(
|
||||
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
|
||||
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
|
||||
LIMIT_COUNT 1)
|
||||
if(NCCL_MAJOR_VERSION_DEFINED)
|
||||
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
|
||||
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
|
||||
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
|
||||
endif()
|
||||
message(
|
||||
STATUS
|
||||
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
|
||||
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
|
||||
endif()
|
@ -8,6 +8,7 @@ target_sources(
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
||||
|
@ -125,13 +125,12 @@ constexpr bool supports_binary_op() {
|
||||
template <typename Op>
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
array& out,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() > 1);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@ -146,7 +145,6 @@ void binary_op_gpu_inplace(
|
||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
@ -219,20 +217,6 @@ void binary_op_gpu_inplace(
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
@ -243,8 +227,7 @@ void binary_op_gpu(
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
std::vector<array> outputs{out};
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
binary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||
}
|
||||
|
||||
#define BINARY_GPU(func) \
|
||||
@ -254,14 +237,6 @@ void binary_op_gpu(
|
||||
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
#define BINARY_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||
auto& s = outputs[0].primitive().stream(); \
|
||||
binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
|
248
mlx/backend/cuda/binary_two.cu
Normal file
248
mlx/backend/cuda/binary_two.cu
Normal file
@ -0,0 +1,248 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[0], b[0]);
|
||||
out_a[0] = out[0];
|
||||
out_b[0] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[0], b[index]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[index], b[0]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[index], b[index]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void binary_g_nd(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out_a,
|
||||
Out* out_b,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), a_strides.data(), b_strides.data());
|
||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_g(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out_a,
|
||||
Out* out_b,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides a_strides,
|
||||
const __grid_constant__ Strides b_strides,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_4d(
|
||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_binary_op() {
|
||||
if (std::is_same_v<Op, DivMod>) {
|
||||
return std::is_same_v<In, Out> &&
|
||||
(std::is_integral_v<Out> || is_floating_v<Out>);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() > 1);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
auto& out_a = outputs[0];
|
||||
auto& out_b = outputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out_a, bopt);
|
||||
set_binary_op_output_data(a, b, out_b, bopt);
|
||||
|
||||
if (out_a.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out_a);
|
||||
encoder.set_output_array(out_b);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
|
||||
MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, {
|
||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out_a);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
bool large = a.data_size() > INT32_MAX ||
|
||||
b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel =
|
||||
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out_a, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(a_strides),
|
||||
const_param<NDIM>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out_a, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel,
|
||||
out_a.data_size(),
|
||||
out_a.shape(),
|
||||
out_a.strides(),
|
||||
LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.data_size());
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do binary op {} on inputs of {} with result of {}.",
|
||||
op,
|
||||
dtype_to_string(a.dtype()),
|
||||
dtype_to_string(out_a.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
void DivMod::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("DivMod::eval_gpu");
|
||||
auto& s = outputs[0].primitive().stream();
|
||||
binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -22,7 +22,7 @@ struct FloorDivide {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return x / y;
|
||||
} else {
|
||||
return trunc(x / y);
|
||||
return truncf(x / y);
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -132,7 +132,7 @@ struct LogAddExp {
|
||||
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||
}
|
||||
constexpr float inf = cuda::std::numeric_limits<float>::infinity();
|
||||
float inf = cuda::std::numeric_limits<float>::infinity();
|
||||
auto maxval = x > y ? x : y;
|
||||
auto minval = x < y ? x : y;
|
||||
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
|
||||
|
@ -5,7 +5,7 @@
|
||||
#pragma once
|
||||
|
||||
// The maximum dimensions of shape/strides passed as kernel parameters.
|
||||
#define MAX_NDIM 8
|
||||
#define MAX_NDIM 10
|
||||
|
||||
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
||||
// warpSize variable exists, using it would prevent compile-time optimizations.
|
||||
|
@ -54,6 +54,28 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace distributed {
|
||||
void AllReduce::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Here I assume for now that in is donatable and contiguous.
|
||||
// TODO
|
||||
|
||||
auto& input = inputs[0];
|
||||
auto& output = outputs[0];
|
||||
|
||||
output.copy_shared_buffer(input);
|
||||
auto& s = stream();
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
distributed::detail::all_sum(group(), input, output, s);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
}
|
||||
} // namespace distributed
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
@ -71,10 +93,8 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
}
|
||||
|
||||
NO_GPU(ArgPartition)
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
NO_GPU(DynamicSlice)
|
||||
NO_GPU(DynamicSliceUpdate)
|
||||
NO_GPU(FFT)
|
||||
@ -83,7 +103,6 @@ NO_GPU(GatherQMM)
|
||||
NO_GPU(Hadamard)
|
||||
NO_GPU(Load)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Partition)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(Scan)
|
||||
@ -100,7 +119,6 @@ NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
NO_GPU_MULTI(AllReduce)
|
||||
NO_GPU_MULTI(AllGather)
|
||||
NO_GPU_MULTI(Send)
|
||||
NO_GPU_MULTI(Recv)
|
||||
|
@ -86,7 +86,6 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
axis += in.ndim();
|
||||
}
|
||||
int nsort = in.shape(axis);
|
||||
int nsegments = in.data_size() / nsort;
|
||||
int last_dim = in.ndim() - 1;
|
||||
|
||||
// If we are not sorting the innermost dimension of a contiguous array,
|
||||
@ -100,7 +99,11 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||
encoder.add_temporary(out);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
@ -134,7 +137,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
indices.data<uint32_t>(),
|
||||
out.data<uint32_t>(),
|
||||
in.data_size(),
|
||||
nsegments,
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
@ -144,7 +147,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
in.data<Type>(),
|
||||
out.data<Type>(),
|
||||
in.data_size(),
|
||||
nsegments,
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
@ -177,4 +180,14 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
gpu_sort(stream(), inputs[0], out, axis_, false);
|
||||
}
|
||||
|
||||
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("ArgPartition::eval_gpu");
|
||||
gpu_sort(stream(), inputs[0], out, axis_, true);
|
||||
}
|
||||
|
||||
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Partition::eval_gpu");
|
||||
gpu_sort(stream(), inputs[0], out, axis_, false);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -6,3 +6,4 @@ target_sources(
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/mpi/mpi.h"
|
||||
#include "mlx/distributed/nccl/nccl.h"
|
||||
#include "mlx/distributed/ring/ring.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
@ -111,6 +112,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
||||
group = mpi::init(strict);
|
||||
} else if (bk == "ring") {
|
||||
group = ring::init(strict);
|
||||
} else if (bk == "nccl") {
|
||||
group = nccl::init(strict);
|
||||
} else if (bk == "any") {
|
||||
group = ring::init(false);
|
||||
bk_ = "ring";
|
||||
|
8
mlx/distributed/nccl/CMakeLists.txt
Normal file
8
mlx/distributed/nccl/CMakeLists.txt
Normal file
@ -0,0 +1,8 @@
|
||||
if(MLX_BUILD_CUDA)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
|
||||
find_package(NCCL REQUIRED)
|
||||
target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
|
||||
target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
|
||||
endif()
|
382
mlx/distributed/nccl/nccl.cpp
Normal file
382
mlx/distributed/nccl/nccl.cpp
Normal file
@ -0,0 +1,382 @@
|
||||
#include <arpa/inet.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <nccl.h>
|
||||
#include <netdb.h>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <mutex>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
|
||||
namespace mlx::core::distributed::nccl {
|
||||
|
||||
#define CHECK_CUDA(cmd) \
|
||||
do { \
|
||||
cudaError_t e = cmd; \
|
||||
if (e != cudaSuccess) { \
|
||||
fprintf( \
|
||||
stderr, \
|
||||
"CUDA error %s:%d '%s'\n", \
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
cudaGetErrorString(e)); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define CHECK_NCCL(cmd) \
|
||||
do { \
|
||||
ncclResult_t r = cmd; \
|
||||
if (r != ncclSuccess) { \
|
||||
fprintf( \
|
||||
stderr, \
|
||||
"NCCL error %s:%d '%s'\n", \
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
ncclGetErrorString(r)); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace detail {
|
||||
|
||||
inline void sendAll(int sock, const void* buf, size_t len) {
|
||||
const char* ptr = reinterpret_cast<const char*>(buf);
|
||||
while (len > 0) {
|
||||
ssize_t sent = send(sock, ptr, len, 0);
|
||||
if (sent <= 0) {
|
||||
perror("send");
|
||||
exit(1);
|
||||
}
|
||||
ptr += sent;
|
||||
len -= sent;
|
||||
}
|
||||
}
|
||||
|
||||
inline void recvAll(int sock, void* buf, size_t len) {
|
||||
char* ptr = reinterpret_cast<char*>(buf);
|
||||
while (len > 0) {
|
||||
ssize_t rec = recv(sock, ptr, len, 0);
|
||||
if (rec <= 0) {
|
||||
perror("recv");
|
||||
exit(1);
|
||||
}
|
||||
ptr += rec;
|
||||
len -= rec;
|
||||
}
|
||||
}
|
||||
|
||||
inline void bootstrap_unique_id(
|
||||
ncclUniqueId& id,
|
||||
int rank,
|
||||
int size,
|
||||
const std::string& initMethod) {
|
||||
|
||||
if (initMethod.rfind("tcp://", 0) != 0)
|
||||
throw;
|
||||
auto hostport = initMethod.substr(6);
|
||||
auto colon = hostport.find(':');
|
||||
std::string host = hostport.substr(0, colon);
|
||||
int port = std::stoi(hostport.substr(colon + 1));
|
||||
|
||||
if (rank == 0) {
|
||||
CHECK_NCCL(ncclGetUniqueId(&id));
|
||||
|
||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[nccl] Couldn't create socket (error: " << errno << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
sockaddr_in serv = {};
|
||||
serv.sin_family = AF_INET;
|
||||
serv.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||
serv.sin_port = htons(port);
|
||||
|
||||
int reuse = 1;
|
||||
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[nccl] setsockopt() failed: " << strerror(errno);
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[nccl] bind() failed: " << strerror(errno);
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
if (listen(sock, size - 1) < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[nccl] listen() failed: " << strerror(errno);
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
for (int peer = 1; peer < size; ++peer) {
|
||||
int conn = accept(sock, nullptr, nullptr);
|
||||
if (conn < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[nccl] accept() failed: " << strerror(errno);
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
sendAll(conn, &id, sizeof(id));
|
||||
close(conn);
|
||||
}
|
||||
close(sock);
|
||||
|
||||
} else {
|
||||
// Here just wanted to make show that rank 0 has enough time to bind
|
||||
// so we will retry to connect until max attempts
|
||||
|
||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[nccl] socket() failed: " << strerror(errno);
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
hostent* he = gethostbyname(host.c_str());
|
||||
if (!he) {
|
||||
throw std::runtime_error("[nccl] lookup failed for host: " + host);
|
||||
}
|
||||
sockaddr_in serv = {};
|
||||
serv.sin_family = AF_INET;
|
||||
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
|
||||
serv.sin_port = htons(port);
|
||||
|
||||
const int max_retries = 30;
|
||||
int attempt = 0;
|
||||
bool connected = false;
|
||||
|
||||
for (attempt = 0; attempt < max_retries; ++attempt) {
|
||||
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
||||
0) {
|
||||
connected = true;
|
||||
std::cout << "[Rank " << rank << "] Connected successfully on attempt "
|
||||
<< attempt + 1 << std::endl;
|
||||
break;
|
||||
}
|
||||
if (errno != ECONNREFUSED) {
|
||||
break;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||
}
|
||||
|
||||
if (!connected) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Rank " << rank << "] connect() failed after " << attempt
|
||||
<< " retries: " << strerror(errno);
|
||||
close(sock);
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
recvAll(sock, &id, sizeof(id));
|
||||
close(sock);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct type_identity {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename F>
|
||||
void dispatch_dtype(const array& arr, F&& f) {
|
||||
switch (arr.dtype()) {
|
||||
case bool_:
|
||||
throw std::invalid_argument("[nccl] Boolean arrays not supported");
|
||||
case int8:
|
||||
f(type_identity<int8_t>{}, ncclChar);
|
||||
break;
|
||||
case uint8:
|
||||
f(type_identity<uint8_t>{}, ncclUint8);
|
||||
break;
|
||||
case int32:
|
||||
f(type_identity<int32_t>{}, ncclInt);
|
||||
break;
|
||||
case uint32:
|
||||
f(type_identity<uint32_t>{}, ncclUint32);
|
||||
break;
|
||||
case int64:
|
||||
f(type_identity<int64_t>{}, ncclInt64);
|
||||
break;
|
||||
case uint64:
|
||||
f(type_identity<uint64_t>{}, ncclUint64);
|
||||
break;
|
||||
case float16:
|
||||
f(type_identity<float16_t>{}, ncclHalf);
|
||||
break;
|
||||
case bfloat16:
|
||||
f(type_identity<bfloat16_t>{}, ncclBfloat16);
|
||||
break;
|
||||
case float32:
|
||||
f(type_identity<float>{}, ncclFloat);
|
||||
break;
|
||||
case float64:
|
||||
f(type_identity<double>{}, ncclDouble);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
class NCCLGroup : public GroupImpl {
|
||||
public:
|
||||
NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
|
||||
: rank_(worldRank),
|
||||
size_(worldSize),
|
||||
comm_(nullptr),
|
||||
initMethod_(initMethod) {
|
||||
if (initialized_)
|
||||
return;
|
||||
int ndev;
|
||||
CHECK_CUDA(cudaGetDeviceCount(&ndev));
|
||||
CHECK_CUDA(cudaSetDevice(rank_ % ndev));
|
||||
CHECK_CUDA(cudaStreamCreate(&stream_));
|
||||
|
||||
detail::bootstrapUniqueId(uniqueId_, rank_, size_, initMethod_);
|
||||
CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
~NCCLGroup() {
|
||||
ncclCommDestroy(comm_);
|
||||
ncclGroupEnd();
|
||||
cudaStreamDestroy(stream_);
|
||||
initialized_ = false;
|
||||
}
|
||||
|
||||
int rank() override {
|
||||
return rank_;
|
||||
}
|
||||
|
||||
int size() override {
|
||||
return size_;
|
||||
}
|
||||
|
||||
void all_sum(const array& input, array& output, Stream stream) override {
|
||||
if (input.size() != output.size()) {
|
||||
throw std::runtime_error(
|
||||
"[nccl] Input and output arrays must have the same size.");
|
||||
}
|
||||
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||
using T = typename decltype(type_tag)::type;
|
||||
all_reduce_impl<T>(input, output, stream, dt, ncclSum);
|
||||
});
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||
throw std::runtime_error("[nccl] Group split not supported.");
|
||||
}
|
||||
|
||||
void all_gather(const array& input, array& output, Stream stream) override {
|
||||
if (input.size() != output.size() / size_) {
|
||||
throw std::runtime_error(
|
||||
"[nccl] Input size must match output size divided by group size.");
|
||||
}
|
||||
}
|
||||
|
||||
void send(const array& input, int dst, Stream stream) override {
|
||||
if (input.size() == 0) {
|
||||
return; // Nothing to send
|
||||
}
|
||||
}
|
||||
|
||||
void recv(array& output, int src, Stream stream) override {
|
||||
if (output.size() == 0) {
|
||||
return; // Nothing to receive
|
||||
}
|
||||
}
|
||||
|
||||
void all_max(const array& input, array& output, Stream stream) override {
|
||||
if (input.size() != output.size()) {
|
||||
throw std::runtime_error(
|
||||
"[nccl] Input and output arrays must have the same size.");
|
||||
}
|
||||
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||
using T = typename decltype(type_tag)::type;
|
||||
all_reduce_impl<T>(input, output, stream, dt, ncclMax);
|
||||
});
|
||||
}
|
||||
|
||||
void all_min(const array& input, array& output, Stream stream) override {
|
||||
if (input.size() != output.size()) {
|
||||
throw std::runtime_error(
|
||||
"[nccl] Input and output arrays must have the same size.");
|
||||
}
|
||||
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||
using T = typename decltype(type_tag)::type;
|
||||
all_reduce_impl<T>(input, output, stream, dt, ncclMin);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void all_reduce_impl(
|
||||
const array& input,
|
||||
array& output,
|
||||
Stream stream,
|
||||
ncclDataType_t dt,
|
||||
ncclRedOp_t op) {
|
||||
|
||||
CHECK_NCCL(ncclAllReduce(
|
||||
input.data<T>(),
|
||||
output.data<T>(),
|
||||
input.size(),
|
||||
dt,
|
||||
op,
|
||||
comm_,
|
||||
stream_));
|
||||
cudaStreamSynchronize(stream_);
|
||||
}
|
||||
|
||||
int rank_, size_;
|
||||
std::string initMethod_;
|
||||
ncclUniqueId uniqueId_;
|
||||
ncclComm_t comm_;
|
||||
cudaStream_t stream_;
|
||||
bool initialized_ = false;
|
||||
};
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
static std::string get_env_var_or_throw(const char* env_var_name) {
|
||||
const char* value = std::getenv(env_var_name);
|
||||
if (value == nullptr) {
|
||||
std::ostringstream msg;
|
||||
msg << "[nccl] Required environment variable '" << env_var_name
|
||||
<< "' is not set. "
|
||||
<< "Please set it before initializing the distributed backend.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return std::string(value);
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||
std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP");
|
||||
std::string port = detail::get_env_var_or_throw("NCCL_PORT");
|
||||
std::string rank_str = detail::get_env_var_or_throw("MLX_RANK");
|
||||
std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE");
|
||||
|
||||
int rank = std::stoi(rank_str);
|
||||
int n_nodes = std::stoi(n_nodes_str);
|
||||
std::string init_method = "tcp://" + host + ":" + port;
|
||||
|
||||
return std::make_shared<NCCLGroup>(rank, n_nodes, init_method);
|
||||
}
|
||||
} // namespace mlx::core::distributed::nccl
|
12
mlx/distributed/nccl/nccl.h
Normal file
12
mlx/distributed/nccl/nccl.h
Normal file
@ -0,0 +1,12 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
|
||||
namespace mlx::core::distributed::nccl {
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
|
||||
bool is_available();
|
||||
std::shared_ptr<GroupImpl> init(bool strict = false);
|
||||
|
||||
} // namespace mlx::core::distributed::nccl
|
20
mlx/distributed/nccl/no_nccl.cpp
Normal file
20
mlx/distributed/nccl/no_nccl.cpp
Normal file
@ -0,0 +1,20 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/distributed/nccl/nccl.h"
|
||||
|
||||
namespace mlx::core::distributed::nccl {
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
|
||||
bool is_available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||
if (strict) {
|
||||
throw std::runtime_error("Cannot initialize nccl distributed backend.");
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed::nccl
|
@ -31,8 +31,7 @@ array all_sum(
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<AllReduce>(
|
||||
to_stream(s, Device::cpu), group, AllReduce::Sum),
|
||||
std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
|
||||
{x});
|
||||
}
|
||||
|
||||
|
@ -1,10 +1,8 @@
|
||||
cuda_skip = {
|
||||
"TestArray.test_api",
|
||||
"TestAutograd.test_update_state",
|
||||
"TestBF16.test_arg_reduction_ops",
|
||||
"TestBF16.test_reduction_ops",
|
||||
"TestBlas.test_complex_gemm",
|
||||
"TestCompile.test_compile_dynamic_dims",
|
||||
"TestEinsum.test_ellipses",
|
||||
"TestEinsum.test_opt_einsum_test_cases",
|
||||
"TestLoad.test_load_f8_e4m3",
|
||||
@ -14,24 +12,14 @@ cuda_skip = {
|
||||
"TestLayers.test_quantized_embedding",
|
||||
"TestLayers.test_sin_pe",
|
||||
"TestLayers.test_upsample",
|
||||
"TestOps.test_array_equal",
|
||||
"TestOps.test_complex_ops",
|
||||
"TestOps.test_dynamic_slicing",
|
||||
"TestOps.test_softmax",
|
||||
"TestOps.test_sort",
|
||||
"TestOps.test_tile",
|
||||
"TestReduce.test_axis_permutation_sums",
|
||||
"TestReduce.test_dtypes",
|
||||
"TestReduce.test_expand_sums",
|
||||
"TestReduce.test_many_reduction_axes",
|
||||
"TestUpsample.test_torch_upsample",
|
||||
# DivMod NYI
|
||||
"TestOps.test_divmod",
|
||||
"TestEval.test_multi_output_eval_during_transform",
|
||||
# Partition NYI
|
||||
"TestAutograd.test_topk_grad",
|
||||
"TestOps.test_argpartition",
|
||||
"TestOps.test_partition",
|
||||
# Block masked matmul NYI
|
||||
"TestBlas.test_block_masked_matmul",
|
||||
# Gather matmul NYI
|
||||
|
Loading…
Reference in New Issue
Block a user