From 0ce20290b934923e3eb9f04b6098094d56f939f5 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 19 Jun 2025 02:53:41 -0700 Subject: [PATCH] Remove segmented reduce and fix row reduce --- mlx/backend/cuda/CMakeLists.txt | 1 - mlx/backend/cuda/reduce/row_reduce.cu | 2 - mlx/backend/cuda/reduce/segmented_reduce.cu | 84 --------------------- 3 files changed, 87 deletions(-) delete mode 100644 mlx/backend/cuda/reduce/segmented_reduce.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 6487d6aab..bf595b5f9 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -32,7 +32,6 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu - ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 80c5cc254..0af6f27cc 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -249,8 +249,6 @@ __global__ void row_reduce_looped( size_t full_blocks = args.row_size / (BLOCK_DIM_X * N_READS); size_t final_offset = full_blocks * BLOCK_DIM_X * N_READS; - in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); - for (size_t n = 0; n < args.non_row_reductions; n++) { for (size_t r = 0; r < full_blocks; r++) { T vals[N_READS]; diff --git a/mlx/backend/cuda/reduce/segmented_reduce.cu b/mlx/backend/cuda/reduce/segmented_reduce.cu deleted file mode 100644 index 114d71809..000000000 --- a/mlx/backend/cuda/reduce/segmented_reduce.cu +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/device/cast_op.cuh" -#include "mlx/backend/cuda/reduce/reduce.cuh" - -#include -#include -#include - -namespace mlx::core { - -template -void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) { - // Allocate temporary storage. - size_t size; - CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...)); - array temp(allocator::malloc(size), {static_cast(size)}, uint8); - encoder.add_temporary(temp); - // Run op. - CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data(), size, args...)); -} - -template -void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) { - // Allocate temporary storage. - size_t size; - CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...)); - array temp(allocator::malloc(size), {static_cast(size)}, uint8); - encoder.add_temporary(temp); - // Run op. - CHECK_CUDA_ERROR( - cub::DeviceSegmentedReduce::Reduce(temp.data(), size, args...)); -} - -struct MultiplyOp { - int factor; - __device__ int operator()(int i) { - return i * factor; - } -}; - -void segmented_reduce( - cu::CommandEncoder& encoder, - const array& in, - array& out, - Reduce::ReduceType reduce_type, - const std::vector& axes, - const ReductionPlan& plan) { - encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using InType = cuda_type_t; - using OutType = cu::ReduceResult::type; - auto in_iter = cu::make_cast_iterator( - thrust::device_pointer_cast(in.data())); - auto out_ptr = thrust::device_pointer_cast(out.data()); - auto init = cu::ReduceInit::value(); - - if (plan.type == ContiguousAllReduce) { - cub_all_reduce( - encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream); - } else if (plan.type == ContiguousReduce) { - auto offsets = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()}); - cub_segmented_reduce( - encoder, - in_iter, - out_ptr, - out.size(), - offsets, - offsets + 1, - OP(), - init, - stream); - } else { - throw std::runtime_error("Unsupported plan in segmented_reduce."); - } - }); - }); - }); -} - -} // namespace mlx::core