mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Changed nccl reduction to be a parrt of cuda grapph
This commit is contained in:
@@ -1,15 +1,5 @@
|
||||
# Find the nccl libraries
|
||||
#
|
||||
# The following variables are optionally searched for defaults NCCL_ROOT_DIR:
|
||||
# Base directory where all NCCL components are found NCCL_INCLUDE_DIR: Directory
|
||||
# where NCCL header is found NCCL_LIB_DIR: Directory where NCCL library is found
|
||||
#
|
||||
# The following are set after configuration is done: NCCL_FOUND
|
||||
# NCCL_INCLUDE_DIRS NCCL_LIBRARIES
|
||||
#
|
||||
# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks install NCCL
|
||||
# in the same location as the CUDA toolkit. See
|
||||
# https://github.com/caffe2/caffe2/issues/1601
|
||||
# FindNCCL.cmake
|
||||
# This module finds the NVIDIA NCCL library and its include directories.
|
||||
|
||||
set(NCCL_ROOT_DIR
|
||||
$ENV{NCCL_ROOT_DIR}
|
||||
|
||||
@@ -19,6 +19,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
|
||||
87
mlx/backend/cuda/distributed.cu
Normal file
87
mlx/backend/cuda/distributed.cu
Normal file
@@ -0,0 +1,87 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
namespace distributed {
|
||||
void AllReduce::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Here I assume for now that in is donatable and contiguous.
|
||||
// TODO
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& input = inputs[0];
|
||||
auto& output = outputs[0];
|
||||
|
||||
auto& encoder = cu::get_command_encoder(stream());
|
||||
output.set_data(allocator::malloc(output.nbytes()));
|
||||
|
||||
encoder.set_input_array(input);
|
||||
encoder.set_output_array(output);
|
||||
|
||||
auto capture = encoder.capture_context();
|
||||
auto& s = stream();
|
||||
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
distributed::detail::all_sum(group(), input, output, s);
|
||||
break;
|
||||
case Max:
|
||||
distributed::detail::all_max(group(), input, output, s);
|
||||
break;
|
||||
case Min:
|
||||
distributed::detail::all_min(group(), input, output, s);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"Only all reduce sum, max, and min are supported.");
|
||||
}
|
||||
}
|
||||
|
||||
void Send::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Here FOR NOW I assume that it is always row_contigious
|
||||
// because not sure how to copy correctly
|
||||
// TODO
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
distributed::detail::send(group(), inputs[0], dst_, stream());
|
||||
outputs[0].copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void Recv::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 0);
|
||||
assert(outputs.size() == 1);
|
||||
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
|
||||
distributed::detail::recv(group(), outputs[0], src_, stream());
|
||||
}
|
||||
|
||||
void AllGather::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Here FOR NOW I assume that it is always row_contigious
|
||||
// because not sure how to copy correctly
|
||||
// TODO
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& input = inputs[0];
|
||||
auto& output = outputs[0];
|
||||
|
||||
output.copy_shared_buffer(input);
|
||||
distributed::detail::all_gather(group(), input, output, stream());
|
||||
}
|
||||
}// namespace distributed
|
||||
}
|
||||
60
mlx/backend/cuda/iterators/strided_iterator.cuh
Normal file
60
mlx/backend/cuda/iterators/strided_iterator.cuh
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <thrust/iterator/iterator_adaptor.h>
|
||||
#include <thrust/iterator/iterator_facade.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// RandomAccessIterator for strided access to array entries.
|
||||
template <typename Iterator, typename Stride = int64_t>
|
||||
class strided_iterator
|
||||
: public thrust::
|
||||
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
|
||||
public:
|
||||
using super_t =
|
||||
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
|
||||
|
||||
using reference = typename super_t::reference;
|
||||
using difference_type = typename super_t::difference_type;
|
||||
|
||||
__host__ __device__ strided_iterator(Iterator it, Stride stride)
|
||||
: super_t(it), stride_(stride) {}
|
||||
|
||||
__host__ __device__ Stride stride() const {
|
||||
return stride_;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class thrust::iterator_core_access;
|
||||
|
||||
__host__ __device__ bool equal(const strided_iterator& other) const {
|
||||
return this->base() == other.base();
|
||||
}
|
||||
|
||||
__host__ __device__ void advance(difference_type n) {
|
||||
this->base_reference() += n * stride_;
|
||||
}
|
||||
|
||||
__host__ __device__ void increment() {
|
||||
this->base_reference() += stride_;
|
||||
}
|
||||
|
||||
__host__ __device__ void decrement() {
|
||||
this->base_reference() -= stride_;
|
||||
}
|
||||
|
||||
__host__ __device__ difference_type
|
||||
distance_to(const strided_iterator& other) const {
|
||||
const difference_type dist = other.base() - this->base();
|
||||
_CCCL_ASSERT(
|
||||
dist % stride() == 0,
|
||||
"Underlying iterator difference must be divisible by the stride");
|
||||
return dist / stride();
|
||||
}
|
||||
|
||||
Stride stride_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "mlx/backend/cuda/gemms/gemv.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <numeric>
|
||||
|
||||
@@ -17,28 +17,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace distributed {
|
||||
void AllReduce::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Here I assume for now that in is donatable and contiguous.
|
||||
// TODO
|
||||
|
||||
auto& input = inputs[0];
|
||||
auto& output = outputs[0];
|
||||
|
||||
output.copy_shared_buffer(input);
|
||||
auto& s = stream();
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
distributed::detail::all_sum(group(), input, output, s);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
}
|
||||
} // namespace distributed
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
@@ -79,10 +57,11 @@ NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
NO_GPU_MULTI(AllGather)
|
||||
NO_GPU_MULTI(Send)
|
||||
NO_GPU_MULTI(Recv)
|
||||
} // namespace distributed
|
||||
// namespace distributed {
|
||||
// NO_GPU_MULTI(AllReduce)
|
||||
// NO_GPU_MULTI(AllGather)
|
||||
// NO_GPU_MULTI(Send)
|
||||
// NO_GPU_MULTI(Recv)
|
||||
// } // namespace distributed
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
@@ -0,0 +1,84 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <cub/device/device_reduce.cuh>
|
||||
#include <cub/device/device_segmented_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename... Args>
|
||||
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(
|
||||
cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
struct MultiplyOp {
|
||||
int factor;
|
||||
__device__ int operator()(int i) {
|
||||
return i * factor;
|
||||
}
|
||||
};
|
||||
|
||||
void segmented_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
auto in_iter = cu::make_cast_iterator<OutType>(
|
||||
thrust::device_pointer_cast(in.data<InType>()));
|
||||
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
|
||||
auto init = cu::ReduceInit<OP, InType>::value();
|
||||
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
cub_all_reduce(
|
||||
encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
|
||||
} else if (plan.type == ContiguousReduce) {
|
||||
auto offsets = thrust::make_transform_iterator(
|
||||
thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
|
||||
cub_segmented_reduce(
|
||||
encoder,
|
||||
in_iter,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
offsets,
|
||||
offsets + 1,
|
||||
OP(),
|
||||
init,
|
||||
stream);
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported plan in segmented_reduce.");
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include <iostream>
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/mpi/mpi.h"
|
||||
@@ -81,7 +82,7 @@ class EmptyGroup : public GroupImpl {
|
||||
} // namespace detail
|
||||
|
||||
bool is_available() {
|
||||
return mpi::is_available() || ring::is_available();
|
||||
return mpi::is_available() || ring::is_available() || nccl::is_available();
|
||||
}
|
||||
|
||||
int Group::rank() const {
|
||||
|
||||
@@ -271,7 +271,12 @@ class NCCLGroup : public GroupImpl {
|
||||
void all_sum(const array& input, array& output, Stream stream) override {
|
||||
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||
using T = typename decltype(type_tag)::type;
|
||||
detail::all_reduce_impl<T>(input, output, stream, dt, ncclSum);
|
||||
all_reduce_impl<T>(
|
||||
input,
|
||||
output,
|
||||
stream,
|
||||
dt,
|
||||
ncclSum);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -281,6 +286,8 @@ class NCCLGroup : public GroupImpl {
|
||||
|
||||
void all_gather(const array& input, array& output, Stream stream) override {
|
||||
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||
auto& encoder = cu::get_command_encoder(stream);
|
||||
|
||||
using T = typename decltype(type_tag)::type;
|
||||
CHECK_NCCL(ncclAllGather(
|
||||
input.data<T>(),
|
||||
@@ -288,12 +295,14 @@ class NCCLGroup : public GroupImpl {
|
||||
input.size(),
|
||||
dt,
|
||||
comm_,
|
||||
cu::get_stream(stream).last_cuda_stream()));
|
||||
encoder.stream()));
|
||||
});
|
||||
}
|
||||
|
||||
void send(const array& input, int dst, Stream stream) override {
|
||||
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||
auto& encoder = cu::get_command_encoder(stream);
|
||||
|
||||
using T = typename decltype(type_tag)::type;
|
||||
CHECK_NCCL(ncclSend(
|
||||
input.data<T>(),
|
||||
@@ -301,20 +310,22 @@ class NCCLGroup : public GroupImpl {
|
||||
dt,
|
||||
dst,
|
||||
comm_,
|
||||
cu::get_stream(stream).last_cuda_stream()));
|
||||
encoder.stream()));
|
||||
});
|
||||
}
|
||||
|
||||
void recv(array& output, int src, Stream stream) override {
|
||||
detail::dispatch_dtype(output, [&](auto type_tag, ncclDataType_t dt) {
|
||||
using T = typename decltype(type_tag)::type;
|
||||
auto& encoder = cu::get_command_encoder(stream);
|
||||
|
||||
CHECK_NCCL(ncclRecv(
|
||||
output.data<T>(),
|
||||
output.size(),
|
||||
dt,
|
||||
src,
|
||||
comm_,
|
||||
cu::get_stream(stream).last_cuda_stream()));
|
||||
encoder.stream()));
|
||||
});
|
||||
}
|
||||
|
||||
@@ -339,6 +350,9 @@ class NCCLGroup : public GroupImpl {
|
||||
Stream stream,
|
||||
ncclDataType_t dt,
|
||||
ncclRedOp_t op) {
|
||||
|
||||
auto& encoder = cu::get_command_encoder(stream);
|
||||
|
||||
CHECK_NCCL(ncclAllReduce(
|
||||
input.data<T>(),
|
||||
output.data<T>(),
|
||||
@@ -346,7 +360,9 @@ class NCCLGroup : public GroupImpl {
|
||||
dt,
|
||||
op,
|
||||
comm_,
|
||||
cu::get_stream(stream).last_cuda_stream()));
|
||||
encoder.stream()
|
||||
));
|
||||
|
||||
}
|
||||
|
||||
int rank_, size_;
|
||||
|
||||
Reference in New Issue
Block a user