Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Changed NCCL reduction to be part of the CUDA graph
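For context on the commit title: an NCCL collective becomes "part of a CUDA graph" through stream capture, where work queued on a capturing stream is recorded as graph nodes instead of executing eagerly. The following is a minimal, self-contained sketch of that mechanism, not MLX code; it assumes an initialized ncclComm_t `comm` and device buffers `send`/`recv` holding `count` floats (NCCL >= 2.9 is required for capture support).

// Minimal sketch, not MLX code: recording an NCCL all-reduce into a CUDA
// graph via stream capture, then replaying it.
#include <cuda_runtime.h>
#include <nccl.h>

void captured_all_sum(
    ncclComm_t comm,
    float* send,
    float* recv,
    size_t count,
    cudaStream_t stream) {
  cudaGraph_t graph;
  cudaGraphExec_t exec;

  // Work queued on `stream` between Begin/EndCapture becomes graph nodes
  // instead of executing immediately.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  ncclAllReduce(send, recv, count, ncclFloat, ncclSum, comm, stream);
  cudaStreamEndCapture(stream, &graph);

  cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);
  cudaGraphLaunch(exec, stream); // replay the captured collective
  cudaStreamSynchronize(stream);
  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
}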
@@ -1,15 +1,5 @@
-# Find the nccl libraries
-#
-# The following variables are optionally searched for defaults NCCL_ROOT_DIR:
-# Base directory where all NCCL components are found NCCL_INCLUDE_DIR: Directory
-# where NCCL header is found NCCL_LIB_DIR: Directory where NCCL library is found
-#
-# The following are set after configuration is done: NCCL_FOUND
-# NCCL_INCLUDE_DIRS NCCL_LIBRARIES
-#
-# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks install NCCL
-# in the same location as the CUDA toolkit. See
-# https://github.com/caffe2/caffe2/issues/1601
+# FindNCCL.cmake
+# This module finds the NVIDIA NCCL library and its include directories.
 
 set(NCCL_ROOT_DIR
     $ENV{NCCL_ROOT_DIR}
@@ -19,6 +19,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
mlx/backend/cuda/distributed.cu (new file, 87 lines)
@@ -0,0 +1,87 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/distributed/primitives.h"
+#include "mlx/primitives.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+
+#include <cassert>
+
+namespace mlx::core {
+namespace distributed {
+
+void AllReduce::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  // TODO: for now, assume the input is donatable and contiguous.
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  auto& input = inputs[0];
+  auto& output = outputs[0];
+
+  auto& encoder = cu::get_command_encoder(stream());
+  output.set_data(allocator::malloc(output.nbytes()));
+
+  encoder.set_input_array(input);
+  encoder.set_output_array(output);
+
+  auto capture = encoder.capture_context();
+  auto& s = stream();
+
+  switch (reduce_type_) {
+    case Sum:
+      distributed::detail::all_sum(group(), input, output, s);
+      break;
+    case Max:
+      distributed::detail::all_max(group(), input, output, s);
+      break;
+    case Min:
+      distributed::detail::all_min(group(), input, output, s);
+      break;
+    default:
+      throw std::runtime_error(
+          "Only all reduce sum, max, and min are supported.");
+  }
+}
+
+void Send::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  // TODO: for now, assume the input is always row-contiguous,
+  // since it is not yet clear how to copy correctly otherwise.
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  distributed::detail::send(group(), inputs[0], dst_, stream());
+  outputs[0].copy_shared_buffer(inputs[0]);
+}
+
+void Recv::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 0);
+  assert(outputs.size() == 1);
+  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
+  distributed::detail::recv(group(), outputs[0], src_, stream());
+}
+
+void AllGather::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  // TODO: for now, assume the input is always row-contiguous,
+  // since it is not yet clear how to copy correctly otherwise.
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  auto& input = inputs[0];
+  auto& output = outputs[0];
+
+  output.copy_shared_buffer(input);
+  distributed::detail::all_gather(group(), input, output, stream());
+}
+
+} // namespace distributed
+} // namespace mlx::core
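The load-bearing line in AllReduce::eval_gpu above is `auto capture = encoder.capture_context();`: while that object is alive, the NCCL call issued by `all_sum` is recorded into the encoder's CUDA graph. The sketch below is a hypothetical illustration of such a RAII context, assuming it is built on CUDA stream capture; the real `CommandEncoder::capture_context()` in MLX may differ substantially.

// Hypothetical RAII capture context, for illustration only; the actual
// MLX implementation may differ.
#include <cuda_runtime.h>

struct CaptureContext {
  cudaStream_t stream;
  cudaGraph_t graph{nullptr};

  explicit CaptureContext(cudaStream_t s) : stream(s) {
    // Start recording work submitted to `stream` into a graph.
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);
  }

  ~CaptureContext() {
    // Stop recording; `graph` now holds the captured node(s) and could be
    // spliced into a larger graph, e.g. with cudaGraphAddChildGraphNode.
    cudaStreamEndCapture(stream, &graph);
  }
};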
mlx/backend/cuda/iterators/strided_iterator.cuh (new file, 60 lines)
@@ -0,0 +1,60 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <thrust/iterator/iterator_adaptor.h>
+#include <thrust/iterator/iterator_facade.h>
+
+namespace mlx::core::cu {
+
+// RandomAccessIterator for strided access to array entries.
+template <typename Iterator, typename Stride = int64_t>
+class strided_iterator
+    : public thrust::
+          iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
+ public:
+  using super_t =
+      thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
+
+  using reference = typename super_t::reference;
+  using difference_type = typename super_t::difference_type;
+
+  __host__ __device__ strided_iterator(Iterator it, Stride stride)
+      : super_t(it), stride_(stride) {}
+
+  __host__ __device__ Stride stride() const {
+    return stride_;
+  }
+
+ private:
+  friend class thrust::iterator_core_access;
+
+  __host__ __device__ bool equal(const strided_iterator& other) const {
+    return this->base() == other.base();
+  }
+
+  __host__ __device__ void advance(difference_type n) {
+    this->base_reference() += n * stride_;
+  }
+
+  __host__ __device__ void increment() {
+    this->base_reference() += stride_;
+  }
+
+  __host__ __device__ void decrement() {
+    this->base_reference() -= stride_;
+  }
+
+  __host__ __device__ difference_type
+  distance_to(const strided_iterator& other) const {
+    const difference_type dist = other.base() - this->base();
+    _CCCL_ASSERT(
+        dist % stride() == 0,
+        "Underlying iterator difference must be divisible by the stride");
+    return dist / stride();
+  }
+
+  Stride stride_;
+};
+
+} // namespace mlx::core::cu
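A usage sketch, added here for illustration and not part of the commit: because `strided_iterator` models a random-access iterator, it composes directly with Thrust algorithms. The hypothetical helper below sums one column of a row-major matrix by striding over it with the row width.

// Sketch: sum column `col` of a row-major `rows x cols` device matrix by
// visiting every `cols`-th element.
#include <thrust/device_vector.h>
#include <thrust/reduce.h>

#include "mlx/backend/cuda/iterators/strided_iterator.cuh"

float sum_column(
    const thrust::device_vector<float>& m, int rows, int cols, int col) {
  using mlx::core::cu::strided_iterator;
  auto first = strided_iterator(m.begin() + col, cols); // element (0, col)
  auto last = strided_iterator(m.begin() + col + rows * cols, cols);
  return thrust::reduce(first, last, 0.0f);
}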
@@ -6,6 +6,7 @@
 #include "mlx/backend/cuda/gemms/gemv.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/primitives.h"
+#include "mlx/utils.h"
 
 #include <nvtx3/nvtx3.hpp>
 #include <numeric>
@@ -17,28 +17,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
   return true;
 }
 
-namespace distributed {
-void AllReduce::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  // Here I assume for now that in is donatable and contiguous.
-  // TODO
-
-  auto& input = inputs[0];
-  auto& output = outputs[0];
-
-  output.copy_shared_buffer(input);
-  auto& s = stream();
-  switch (reduce_type_) {
-    case Sum:
-      distributed::detail::all_sum(group(), input, output, s);
-      break;
-    default:
-      throw std::runtime_error("Only all reduce sum is supported for now");
-  }
-}
-} // namespace distributed
-
 #define NO_GPU_MULTI(func)                                             \
   void func::eval_gpu(                                                 \
       const std::vector<array>& inputs, std::vector<array>& outputs) { \
@@ -79,10 +57,11 @@ NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(CustomKernel)
 } // namespace fast
 
-namespace distributed {
-NO_GPU_MULTI(AllGather)
-NO_GPU_MULTI(Send)
-NO_GPU_MULTI(Recv)
-} // namespace distributed
+// namespace distributed {
+// NO_GPU_MULTI(AllReduce)
+// NO_GPU_MULTI(AllGather)
+// NO_GPU_MULTI(Send)
+// NO_GPU_MULTI(Recv)
+// } // namespace distributed
 
 } // namespace mlx::core
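For reference, the `NO_GPU_MULTI` macro whose uses are commented out above (and whose definition is truncated in the hunk context) presumably expands to an `eval_gpu` that just throws; commenting out those stubs is what routes the distributed primitives to the new implementations in distributed.cu. A hedged reconstruction follows; the exact error text is an assumption.

// Hedged reconstruction of the stub macro; the exact message in MLX may
// differ.
#define NO_GPU_MULTI(func)                                             \
  void func::eval_gpu(                                                 \
      const std::vector<array>& inputs, std::vector<array>& outputs) { \
    throw std::runtime_error(#func " has no CUDA implementation.");    \
  }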
mlx/backend/cuda/reduce/segmented_reduce.cu (new file, 84 lines)
@@ -0,0 +1,84 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/device/cast_op.cuh"
+#include "mlx/backend/cuda/reduce/reduce.cuh"
+
+#include <thrust/device_ptr.h>
+#include <cub/device/device_reduce.cuh>
+#include <cub/device/device_segmented_reduce.cuh>
+
+namespace mlx::core {
+
+template <typename... Args>
+void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
+  // Allocate temporary storage.
+  size_t size;
+  CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
+  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
+  encoder.add_temporary(temp);
+  // Run op.
+  CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
+}
+
+template <typename... Args>
+void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
+  // Allocate temporary storage.
+  size_t size;
+  CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
+  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
+  encoder.add_temporary(temp);
+  // Run op.
+  CHECK_CUDA_ERROR(
+      cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
+}
+
+struct MultiplyOp {
+  int factor;
+  __device__ int operator()(int i) {
+    return i * factor;
+  }
+};
+
+void segmented_reduce(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan) {
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        using InType = cuda_type_t<CTYPE>;
+        using OutType = cu::ReduceResult<OP, InType>::type;
+        auto in_iter = cu::make_cast_iterator<OutType>(
+            thrust::device_pointer_cast(in.data<InType>()));
+        auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
+        auto init = cu::ReduceInit<OP, InType>::value();
+
+        if (plan.type == ContiguousAllReduce) {
+          cub_all_reduce(
+              encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
+        } else if (plan.type == ContiguousReduce) {
+          auto offsets = thrust::make_transform_iterator(
+              thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
+          cub_segmented_reduce(
+              encoder,
+              in_iter,
+              out_ptr,
+              out.size(),
+              offsets,
+              offsets + 1,
+              OP(),
+              init,
+              stream);
+        } else {
+          throw std::runtime_error("Unsupported plan in segmented_reduce.");
+        }
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
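The `MultiplyOp` plus counting-iterator pairing above is the entire segmentation scheme: segment `i` of a contiguous reduction spans `[i * row_size, (i + 1) * row_size)`, so the same computed iterator serves as both the begin offsets (`offsets`) and the end offsets (`offsets + 1`), with no offsets array materialized in memory. A standalone sketch of the same trick with plain CUB follows; the names here are illustrative, not from the commit.

// Standalone sketch: per-row sums of a contiguous (n_rows x row_size)
// device array using CUB's segmented reduce and computed offsets.
#include <cuda_runtime.h>
#include <cub/device/device_segmented_reduce.cuh>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

struct MulOp {
  int factor;
  __host__ __device__ int operator()(int i) const { return i * factor; }
};

void sum_rows(
    const float* in, float* out, int n_rows, int row_size, cudaStream_t stream) {
  auto offsets = thrust::make_transform_iterator(
      thrust::make_counting_iterator(0), MulOp{row_size});

  // First call only sizes the temporary storage; the second runs the op.
  size_t temp_bytes = 0;
  cub::DeviceSegmentedReduce::Sum(
      nullptr, temp_bytes, in, out, n_rows, offsets, offsets + 1, stream);
  void* temp = nullptr;
  cudaMalloc(&temp, temp_bytes);
  cub::DeviceSegmentedReduce::Sum(
      temp, temp_bytes, in, out, n_rows, offsets, offsets + 1, stream);
  cudaFree(temp);
}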
@@ -2,6 +2,7 @@
 
 #include <unordered_map>
 
+#include <iostream>
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
 #include "mlx/distributed/mpi/mpi.h"
@@ -81,7 +82,7 @@ class EmptyGroup : public GroupImpl {
 } // namespace detail
 
 bool is_available() {
-  return mpi::is_available() || ring::is_available();
+  return mpi::is_available() || ring::is_available() || nccl::is_available();
 }
 
 int Group::rank() const {
@@ -271,7 +271,12 @@ class NCCLGroup : public GroupImpl {
   void all_sum(const array& input, array& output, Stream stream) override {
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
-      detail::all_reduce_impl<T>(input, output, stream, dt, ncclSum);
+      all_reduce_impl<T>(
+          input,
+          output,
+          stream,
+          dt,
+          ncclSum);
     });
   }
 
@@ -281,6 +286,8 @@ class NCCLGroup : public GroupImpl {
 
   void all_gather(const array& input, array& output, Stream stream) override {
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      auto& encoder = cu::get_command_encoder(stream);
+
       using T = typename decltype(type_tag)::type;
       CHECK_NCCL(ncclAllGather(
           input.data<T>(),
@@ -288,12 +295,14 @@ class NCCLGroup : public GroupImpl {
           input.size(),
           dt,
           comm_,
-          cu::get_stream(stream).last_cuda_stream()));
+          encoder.stream()));
     });
   }
 
   void send(const array& input, int dst, Stream stream) override {
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      auto& encoder = cu::get_command_encoder(stream);
+
       using T = typename decltype(type_tag)::type;
       CHECK_NCCL(ncclSend(
           input.data<T>(),
@@ -301,20 +310,22 @@ class NCCLGroup : public GroupImpl {
           dt,
           dst,
           comm_,
-          cu::get_stream(stream).last_cuda_stream()));
+          encoder.stream()));
     });
   }
 
   void recv(array& output, int src, Stream stream) override {
     detail::dispatch_dtype(output, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
+      auto& encoder = cu::get_command_encoder(stream);
+
       CHECK_NCCL(ncclRecv(
           output.data<T>(),
           output.size(),
           dt,
           src,
           comm_,
-          cu::get_stream(stream).last_cuda_stream()));
+          encoder.stream()));
     });
   }
 
@@ -339,6 +350,9 @@ class NCCLGroup : public GroupImpl {
       Stream stream,
       ncclDataType_t dt,
       ncclRedOp_t op) {
+
+    auto& encoder = cu::get_command_encoder(stream);
+
     CHECK_NCCL(ncclAllReduce(
         input.data<T>(),
         output.data<T>(),
@@ -346,7 +360,9 @@ class NCCLGroup : public GroupImpl {
         dt,
         op,
         comm_,
-        cu::get_stream(stream).last_cuda_stream()));
+        encoder.stream()
+    ));
+
   }
 
   int rank_, size_;
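The recurring substitution in these hunks, `cu::get_stream(stream).last_cuda_stream()` becoming `encoder.stream()`, is what the commit title describes: during graph capture, an NCCL call only becomes a graph node if it is issued on the stream being captured; on any other stream it would escape, or invalidate, the capture. A hedged fragment illustrating the invariant follows; the function and variable names are illustrative, not MLX code.

// Sketch: assert that the collective is issued on a stream that is
// actively being captured, so it lands inside the CUDA graph.
#include <cassert>
#include <cuda_runtime.h>
#include <nccl.h>

void all_sum_in_capture(
    ncclComm_t comm,
    float* send,
    float* recv,
    size_t count,
    cudaStream_t capture_stream) {
  cudaStreamCaptureStatus status;
  cudaStreamIsCapturing(capture_stream, &status);
  assert(status == cudaStreamCaptureStatusActive); // must be inside capture
  ncclAllReduce(send, recv, count, ncclFloat, ncclSum, comm, capture_stream);
}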