// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"

// NOTE(review): the angle-bracket include targets, the template parameter
// lists and the explicit template arguments in this file were missing from
// the text under review. They have been restored from how each name is used
// below; confirm them against version control before merging.
#include <thrust/device_ptr.h>
#include <cub/device/device_reduce.cuh>
#include <cub/device/device_segmented_reduce.cuh>

#include <stdexcept>
#include <vector>

namespace mlx::core {

// Runs cub::DeviceReduce::Reduce on `args`, supplying the temporary storage
// that CUB requires.
//
// CUB device-wide algorithms are invoked twice: once with a null storage
// pointer to query the scratch size, and once with real storage to do the
// work. `args` is therefore expanded twice and deliberately NOT
// std::forward'ed, so that no argument is moved-from before the second call.
//
// encoder: receives the scratch array through add_temporary().
// args:    everything cub::DeviceReduce::Reduce takes after the
//          (d_temp_storage, temp_storage_bytes) pair.
template <typename... Args>
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
  // Query how many bytes of scratch space the reduction needs.
  size_t size = 0;
  CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));

  // Allocate the scratch space as a flat uint8 array and hand it to the
  // encoder.
  // NOTE(review): presumably add_temporary() keeps the buffer alive until the
  // enqueued work has finished — confirm against CommandEncoder.
  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
  encoder.add_temporary(temp);

  // Run the reduction for real.
  CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
}

// Same two-phase pattern as cub_all_reduce, but for
// cub::DeviceSegmentedReduce::Reduce.
//
// encoder: receives the scratch array through add_temporary().
// args:    everything cub::DeviceSegmentedReduce::Reduce takes after the
//          (d_temp_storage, temp_storage_bytes) pair. Expanded twice, so not
//          forwarded.
template <typename... Args>
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
  // Query how many bytes of scratch space the reduction needs.
  size_t size = 0;
  CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));

  // Allocate the scratch space as a flat uint8 array and hand it to the
  // encoder.
  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
  encoder.add_temporary(temp);

  // Run the reduction for real.
  CHECK_CUDA_ERROR(
      cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
}

// Unary functor mapping an index `i` to `i * factor`. Combined with a
// counting iterator it yields the sequence 0, factor, 2 * factor, ... which
// segmented_reduce uses as segment offsets.
//
// NOTE(review): the product is computed in `int`, so it overflows once
// `i * factor` exceeds INT_MAX. Widening it changes the offset type seen by
// CUB and is left for a follow-up.
struct MultiplyOp {
  int factor;
  __device__ int operator()(int i) const {
    return i * factor;
  }
};

// Reduces `in` into `out` with CUB, dispatching on `plan.type`:
//
//  - ContiguousAllReduce: one reduction over all `in.data_size()` elements.
//  - ContiguousReduce:    `out.size()` segments of `plan.shape.back()`
//                         elements each; segment k covers
//                         [k * plan.shape.back(), (k + 1) * plan.shape.back()).
//  - anything else:       throws std::runtime_error.
//
// encoder:     used to launch the work and to hold CUB's scratch buffer.
// in:          input array; read through a cast iterator.
// out:         output array; written through a thrust device pointer.
// reduce_type: selects the reduction operator via MLX_SWITCH_REDUCE_OPS.
// axes:        not read by this function.
// plan:        supplies `type` and `shape` for the dispatch above.
void segmented_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector<int>& axes,
    const ReductionPlan& plan) {
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
        using InType = cuda_type_t<CTYPE>;
        using OutType = cu::ReduceResult<OP, InType>::type;

        // Input elements are converted by the cast iterator as they are read.
        auto in_iter = cu::make_cast_iterator<OutType>(
            thrust::device_pointer_cast(in.data<InType>()));
        auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
        auto init = cu::ReduceInit<OP, InType>::value();

        if (plan.type == ContiguousAllReduce) {
          cub_all_reduce(
              encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
        } else if (plan.type == ContiguousReduce) {
          // offsets[k] == k * plan.shape.back(); passing `offsets + 1` as the
          // end-offset iterator makes segment k end where segment k + 1
          // begins.
          auto offsets = thrust::make_transform_iterator(
              thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
          cub_segmented_reduce(
              encoder,
              in_iter,
              out_ptr,
              out.size(),
              offsets,
              offsets + 1,
              OP(),
              init,
              stream);
        } else {
          throw std::runtime_error("Unsupported plan in segmented_reduce.");
        }
      });
    });
  });
}

} // namespace mlx::core