mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Remove segmented reduce and fix row reduce
This commit is contained in:
parent
6a59c92457
commit
0ce20290b9
@ -32,7 +32,6 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||||
|
@ -249,8 +249,6 @@ __global__ void row_reduce_looped(
|
|||||||
size_t full_blocks = args.row_size / (BLOCK_DIM_X * N_READS);
|
size_t full_blocks = args.row_size / (BLOCK_DIM_X * N_READS);
|
||||||
size_t final_offset = full_blocks * BLOCK_DIM_X * N_READS;
|
size_t final_offset = full_blocks * BLOCK_DIM_X * N_READS;
|
||||||
|
|
||||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
|
||||||
|
|
||||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||||
for (size_t r = 0; r < full_blocks; r++) {
|
for (size_t r = 0; r < full_blocks; r++) {
|
||||||
T vals[N_READS];
|
T vals[N_READS];
|
||||||
|
@ -1,84 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
|
||||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
|
||||||
|
|
||||||
#include <thrust/device_ptr.h>
|
|
||||||
#include <cub/device/device_reduce.cuh>
|
|
||||||
#include <cub/device/device_segmented_reduce.cuh>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
template <typename... Args>
|
|
||||||
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
|
||||||
// Allocate temporary storage.
|
|
||||||
size_t size;
|
|
||||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
|
|
||||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
|
||||||
encoder.add_temporary(temp);
|
|
||||||
// Run op.
|
|
||||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename... Args>
|
|
||||||
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
|
||||||
// Allocate temporary storage.
|
|
||||||
size_t size;
|
|
||||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
|
|
||||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
|
||||||
encoder.add_temporary(temp);
|
|
||||||
// Run op.
|
|
||||||
CHECK_CUDA_ERROR(
|
|
||||||
cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
|
|
||||||
}
|
|
||||||
|
|
||||||
struct MultiplyOp {
|
|
||||||
int factor;
|
|
||||||
__device__ int operator()(int i) {
|
|
||||||
return i * factor;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void segmented_reduce(
|
|
||||||
cu::CommandEncoder& encoder,
|
|
||||||
const array& in,
|
|
||||||
array& out,
|
|
||||||
Reduce::ReduceType reduce_type,
|
|
||||||
const std::vector<int>& axes,
|
|
||||||
const ReductionPlan& plan) {
|
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
|
||||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
|
||||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
|
||||||
using InType = cuda_type_t<CTYPE>;
|
|
||||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
|
||||||
auto in_iter = cu::make_cast_iterator<OutType>(
|
|
||||||
thrust::device_pointer_cast(in.data<InType>()));
|
|
||||||
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
|
|
||||||
auto init = cu::ReduceInit<OP, InType>::value();
|
|
||||||
|
|
||||||
if (plan.type == ContiguousAllReduce) {
|
|
||||||
cub_all_reduce(
|
|
||||||
encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
|
|
||||||
} else if (plan.type == ContiguousReduce) {
|
|
||||||
auto offsets = thrust::make_transform_iterator(
|
|
||||||
thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
|
|
||||||
cub_segmented_reduce(
|
|
||||||
encoder,
|
|
||||||
in_iter,
|
|
||||||
out_ptr,
|
|
||||||
out.size(),
|
|
||||||
offsets,
|
|
||||||
offsets + 1,
|
|
||||||
OP(),
|
|
||||||
init,
|
|
||||||
stream);
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error("Unsupported plan in segmented_reduce.");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
Loading…
Reference in New Issue
Block a user