diff --git a/cmake/FindNCCL.cmake b/cmake/FindNCCL.cmake
index 7f8791476..917640f0d 100644
--- a/cmake/FindNCCL.cmake
+++ b/cmake/FindNCCL.cmake
@@ -1,15 +1,5 @@
-# Find the nccl libraries
-#
-# The following variables are optionally searched for defaults NCCL_ROOT_DIR:
-# Base directory where all NCCL components are found NCCL_INCLUDE_DIR: Directory
-# where NCCL header is found NCCL_LIB_DIR: Directory where NCCL library is found
-#
-# The following are set after configuration is done: NCCL_FOUND
-# NCCL_INCLUDE_DIRS NCCL_LIBRARIES
-#
-# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks install NCCL
-# in the same location as the CUDA toolkit. See
-# https://github.com/caffe2/caffe2/issues/1601
+# FindNCCL.cmake
+# This module finds the NVIDIA NCCL library and its include directories.
 set(NCCL_ROOT_DIR
     $ENV{NCCL_ROOT_DIR}
diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 8c1b999e9..5e0f970da 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -19,6 +19,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
diff --git a/mlx/backend/cuda/distributed.cu b/mlx/backend/cuda/distributed.cu
new file mode 100644
index 000000000..df0fe4539
--- /dev/null
+++ b/mlx/backend/cuda/distributed.cu
@@ -0,0 +1,87 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/distributed/primitives.h"
+#include "mlx/primitives.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+
+#include <cassert>
+
+namespace mlx::core {
+namespace distributed {
+
+void AllReduce::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  // For now, assume the input is donatable and contiguous.
+  // TODO: handle non-contiguous and non-donatable inputs.
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  auto& input = inputs[0];
+  auto& output = outputs[0];
+
+  auto& encoder = cu::get_command_encoder(stream());
+  output.set_data(allocator::malloc(output.nbytes()));
+
+  encoder.set_input_array(input);
+  encoder.set_output_array(output);
+
+  auto capture = encoder.capture_context();
+  auto& s = stream();
+
+  switch (reduce_type_) {
+    case Sum:
+      distributed::detail::all_sum(group(), input, output, s);
+      break;
+    case Max:
+      distributed::detail::all_max(group(), input, output, s);
+      break;
+    case Min:
+      distributed::detail::all_min(group(), input, output, s);
+      break;
+    default:
+      throw std::runtime_error(
+          "Only all reduce sum, max, and min are supported.");
+  }
+}
+
+void Send::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  // For now, assume the input is always row contiguous, since it is not
+  // clear yet how to copy correctly otherwise.
+  // TODO: handle non-contiguous inputs.
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  distributed::detail::send(group(), inputs[0], dst_, stream());
+  outputs[0].copy_shared_buffer(inputs[0]);
+}
+
+void Recv::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 0);
+  assert(outputs.size() == 1);
+  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
+  distributed::detail::recv(group(), outputs[0], src_, stream());
+}
+
+void AllGather::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  // For now, assume the input is always row contiguous, since it is not
+  // clear yet how to copy correctly otherwise.
+  // TODO: handle non-contiguous inputs.
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  auto& input = inputs[0];
+  auto& output = outputs[0];
+
+  // The gathered output is larger than the input, so it needs its own
+  // allocation rather than sharing the input buffer.
+  output.set_data(allocator::malloc(output.nbytes()));
+  distributed::detail::all_gather(group(), input, output, stream());
+}
+
+} // namespace distributed
+} // namespace mlx::core
\ No newline at end of file
diff --git a/mlx/backend/cuda/iterators/strided_iterator.cuh b/mlx/backend/cuda/iterators/strided_iterator.cuh
new file mode 100644
index 000000000..3ef8d66bd
--- /dev/null
+++ b/mlx/backend/cuda/iterators/strided_iterator.cuh
@@ -0,0 +1,60 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <cuda/std/utility>
+#include <thrust/iterator/iterator_adaptor.h>
+
+namespace mlx::core::cu {
+
+// RandomAccessIterator for strided access to array entries.
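+//
+// For example, strided_iterator(ptr, 3) visits ptr[0], ptr[3], ptr[6], ...
+// so Thrust and CUB algorithms can walk, e.g., one column of a row-major
+// matrix without materializing a copy.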
+template <typename Iterator, typename Stride>
+class strided_iterator
+    : public thrust::
+          iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
+ public:
+  using super_t =
+      thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
+
+  using reference = typename super_t::reference;
+  using difference_type = typename super_t::difference_type;
+
+  __host__ __device__ strided_iterator(Iterator it, Stride stride)
+      : super_t(it), stride_(stride) {}
+
+  __host__ __device__ Stride stride() const {
+    return stride_;
+  }
+
+ private:
+  friend class thrust::iterator_core_access;
+
+  __host__ __device__ bool equal(const strided_iterator& other) const {
+    return this->base() == other.base();
+  }
+
+  __host__ __device__ void advance(difference_type n) {
+    this->base_reference() += n * stride_;
+  }
+
+  __host__ __device__ void increment() {
+    this->base_reference() += stride_;
+  }
+
+  __host__ __device__ void decrement() {
+    this->base_reference() -= stride_;
+  }
+
+  __host__ __device__ difference_type
+  distance_to(const strided_iterator& other) const {
+    const difference_type dist = other.base() - this->base();
+    _CCCL_ASSERT(
+        dist % stride() == 0,
+        "Underlying iterator difference must be divisible by the stride");
+    return dist / stride();
+  }
+
+  Stride stride_;
+};
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 283aaaf2e..93346d887 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -6,6 +6,7 @@
 #include "mlx/backend/cuda/gemms/gemv.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/primitives.h"
+#include "mlx/utils.h"
 #include
 #include
diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp
index c471fa8c2..08a457eaf 100644
--- a/mlx/backend/cuda/primitives.cpp
+++ b/mlx/backend/cuda/primitives.cpp
@@ -17,28 +17,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
   return true;
 }
 
-namespace distributed {
-void AllReduce::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  // Here I assume for now that in is donatable and contiguous.
-  // TODO
-
-  auto& input = inputs[0];
-  auto& output = outputs[0];
-
-  output.copy_shared_buffer(input);
-  auto& s = stream();
-  switch (reduce_type_) {
-    case Sum:
-      distributed::detail::all_sum(group(), input, output, s);
-      break;
-    default:
-      throw std::runtime_error("Only all reduce sum is supported for now");
-  }
-}
-} // namespace distributed
-
 #define NO_GPU_MULTI(func)                                                 \
   void func::eval_gpu(                                                     \
       const std::vector<array>& inputs, std::vector<array>& outputs) {     \
@@ -79,10 +57,11 @@ NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(CustomKernel)
 } // namespace fast
 
-namespace distributed {
-NO_GPU_MULTI(AllGather)
-NO_GPU_MULTI(Send)
-NO_GPU_MULTI(Recv)
-} // namespace distributed
-
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/reduce/segmented_reduce.cu b/mlx/backend/cuda/reduce/segmented_reduce.cu
new file mode 100644
index 000000000..114d71809
--- /dev/null
+++ b/mlx/backend/cuda/reduce/segmented_reduce.cu
@@ -0,0 +1,84 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/device/cast_op.cuh"
+#include "mlx/backend/cuda/reduce/reduce.cuh"
+
+#include <thrust/device_ptr.h>
+#include <cub/device/device_reduce.cuh>
+#include <cub/device/device_segmented_reduce.cuh>
+
+namespace mlx::core {
+
+template <typename... Args>
+void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
+  // Allocate temporary storage.
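+  // cub::DeviceReduce::Reduce follows CUB's two-phase pattern: the first call
+  // with a null temporary-storage pointer only writes the required byte count
+  // into `size`; the second call below performs the actual reduction.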
+  size_t size;
+  CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
+  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
+  encoder.add_temporary(temp);
+  // Run op.
+  CHECK_CUDA_ERROR(
+      cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
+}
+
+template <typename... Args>
+void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
+  // Allocate temporary storage.
+  size_t size;
+  CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
+  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
+  encoder.add_temporary(temp);
+  // Run op.
+  CHECK_CUDA_ERROR(
+      cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
+}
+
+struct MultiplyOp {
+  int factor;
+  __device__ int operator()(int i) {
+    return i * factor;
+  }
+};
+
+void segmented_reduce(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan) {
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        using InType = cuda_type_t<CTYPE>;
+        using OutType = cu::ReduceResult<OP, InType>::type;
+        auto in_iter = cu::make_cast_iterator<OutType>(
+            thrust::device_pointer_cast(in.data<InType>()));
+        auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
+        auto init = cu::ReduceInit<OP, InType>::value();
+
+        if (plan.type == ContiguousAllReduce) {
+          cub_all_reduce(
+              encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
+        } else if (plan.type == ContiguousReduce) {
+          auto offsets = thrust::make_transform_iterator(
+              thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
+          cub_segmented_reduce(
+              encoder,
+              in_iter,
+              out_ptr,
+              out.size(),
+              offsets,
+              offsets + 1,
+              OP(),
+              init,
+              stream);
+        } else {
+          throw std::runtime_error("Unsupported plan in segmented_reduce.");
+        }
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp
index f791ee29e..a65329588 100644
--- a/mlx/distributed/distributed.cpp
+++ b/mlx/distributed/distributed.cpp
@@ -2,6 +2,7 @@
 #include
 
+#include "mlx/distributed/nccl/nccl.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
 #include "mlx/distributed/mpi/mpi.h"
@@ -81,7 +82,7 @@ class EmptyGroup : public GroupImpl {
 } // namespace detail
 
 bool is_available() {
-  return mpi::is_available() || ring::is_available();
+  return mpi::is_available() || ring::is_available() || nccl::is_available();
 }
 
 int Group::rank() const {
diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp
index 02b1fc20c..c29851271 100644
--- a/mlx/distributed/nccl/nccl.cpp
+++ b/mlx/distributed/nccl/nccl.cpp
@@ -271,7 +271,12 @@ class NCCLGroup : public GroupImpl {
   void all_sum(const array& input, array& output, Stream stream) override {
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
-      detail::all_reduce_impl<T>(input, output, stream, dt, ncclSum);
+      all_reduce_impl<T>(input, output, stream, dt, ncclSum);
     });
   }
@@ -281,6 +286,8 @@
   void all_gather(const array& input, array& output, Stream stream) override {
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      auto& encoder = cu::get_command_encoder(stream);
+
       using T = typename decltype(type_tag)::type;
       CHECK_NCCL(ncclAllGather(
           input.data<T>(),
           output.data<T>(),
           input.size(),
           dt,
           comm_,
-          cu::get_stream(stream).last_cuda_stream()));
+          encoder.stream()));
     });
   }
 
   void send(const array& input, int dst, Stream stream) override {
     detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      auto& encoder = cu::get_command_encoder(stream);
+
       using T = typename decltype(type_tag)::type;
       CHECK_NCCL(ncclSend(
           input.data<T>(),
           input.size(),
           dt,
           dst,
           comm_,
-          cu::get_stream(stream).last_cuda_stream()));
+          encoder.stream()));
     });
   }
 
   void recv(array& output, int src, Stream stream) override {
     detail::dispatch_dtype(output, [&](auto type_tag, ncclDataType_t dt) {
       using T = typename decltype(type_tag)::type;
+      auto& encoder = cu::get_command_encoder(stream);
+
       CHECK_NCCL(ncclRecv(
           output.data<T>(),
           output.size(),
           dt,
           src,
           comm_,
-          cu::get_stream(stream).last_cuda_stream()));
+          encoder.stream()));
     });
   }
@@ -339,6 +350,9 @@
       Stream stream,
       ncclDataType_t dt,
       ncclRedOp_t op) {
+    auto& encoder = cu::get_command_encoder(stream);
+
     CHECK_NCCL(ncclAllReduce(
         input.data<T>(),
         output.data<T>(),
         input.size(),
         dt,
         op,
         comm_,
-        cu::get_stream(stream).last_cuda_stream()));
+        encoder.stream()));
   }
 
   int rank_, size_;