diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index bf595b5f9..8130d396f 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -31,6 +31,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu
index 8922000e2..8936bbf71 100644
--- a/mlx/backend/cuda/reduce.cu
+++ b/mlx/backend/cuda/reduce.cu
@@ -25,7 +25,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& encoder = cu::get_command_encoder(s);
 
   if (in.size() == 0) {
-    throw std::runtime_error("Should never reach here.");
+    init_reduce(encoder, in, out, reduce_type_);
+    return;
   }
 
   // Reduce.
diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu
new file mode 100644
index 000000000..a500dc04e
--- /dev/null
+++ b/mlx/backend/cuda/reduce/init_reduce.cu
@@ -0,0 +1,51 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/reduce/reduce.cuh"
+
+#include <cooperative_groups.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename T, typename U, typename Op>
+__global__ void init_reduce(U* out, size_t size) {
+  auto index = cg::this_grid().thread_rank();
+  if (index < size) {
+    out[index] = ReduceInit<Op, T>::value();
+  }
+}
+
+} // namespace cu
+
+void init_reduce(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type) {
+  // Allocate if needed
+  if (out.data_shared_ptr() == nullptr) {
+    out.set_data(allocator::malloc(out.nbytes()));
+  }
+
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        using T = cuda_type_t<CTYPE>;
+        using U = cu::ReduceResult<OP, T>::type;
+        auto kernel = cu::init_reduce<T, U, OP>;
+        dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
+        dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
+        grid.x = (grid.x + 1023) / 1024;
+        kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh
index 07041efce..a7262bcc2 100644
--- a/mlx/backend/cuda/reduce/reduce.cuh
+++ b/mlx/backend/cuda/reduce/reduce.cuh
@@ -53,14 +53,6 @@ void all_reduce(
     array& out,
     Reduce::ReduceType reduce_type);
 
-void segmented_reduce(
-    cu::CommandEncoder& encoder,
-    const array& in,
-    array& out,
-    Reduce::ReduceType reduce_type,
-    const std::vector<int>& axes,
-    const ReductionPlan& plan);
-
 void row_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
@@ -77,4 +69,10 @@ void col_reduce(
     const std::vector<int>& axes,
     const ReductionPlan& plan);
 
+void init_reduce(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type);
+
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu
index 12bf8897b..4c005cc52 100644
--- a/mlx/backend/cuda/reduce/row_reduce.cu
+++ b/mlx/backend/cuda/reduce/row_reduce.cu
@@ -234,7 +234,7 @@ void row_reduce_simple(
       using U = cu::ReduceResult<OP, T>::type;
 
       // Calculate the grid and block dims
-      size_t reductions = plan.shape.back() / N_READS;
+      size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
       dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
       int threads = std::min(1024UL, reductions);
       threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
@@ -284,7 +284,7 @@ void row_reduce_looped(
       // Calculate the grid and block dims
       args.convert_shapes_to_contiguous(x, axes);
       dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
-      size_t reductions = args.row_size / N_READS;
+      size_t reductions = (args.row_size + N_READS - 1) / N_READS;
       int threads = std::min(1024UL, reductions);
       threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
       dim3 block(threads, 1, 1);