mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Add an init reduce
This commit is contained in:
parent
cc4b995723
commit
818e8e663e
@ -31,6 +31,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||||
|
@ -25,7 +25,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
|
||||||
if (in.size() == 0) {
|
if (in.size() == 0) {
|
||||||
throw std::runtime_error("Should never reach here.");
|
init_reduce(encoder, in, out, reduce_type_);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reduce.
|
// Reduce.
|
||||||
|
51
mlx/backend/cuda/reduce/init_reduce.cu
Normal file
51
mlx/backend/cuda/reduce/init_reduce.cu
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
__global__ void init_reduce(U* out, size_t size) {
|
||||||
|
auto index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
out[index] = ReduceInit<Op, T>::value();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
void init_reduce(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
Reduce::ReduceType reduce_type) {
|
||||||
|
// Allocate if needed
|
||||||
|
if (out.data_shared_ptr() == nullptr) {
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
|
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||||
|
using T = cuda_type_t<CTYPE>;
|
||||||
|
using U = cu::ReduceResult<OP, T>::type;
|
||||||
|
auto kernel = cu::init_reduce<T, U, OP>;
|
||||||
|
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
|
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
|
||||||
|
grid.x = (grid.x + 1023) / 1024;
|
||||||
|
kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -53,14 +53,6 @@ void all_reduce(
|
|||||||
array& out,
|
array& out,
|
||||||
Reduce::ReduceType reduce_type);
|
Reduce::ReduceType reduce_type);
|
||||||
|
|
||||||
void segmented_reduce(
|
|
||||||
cu::CommandEncoder& encoder,
|
|
||||||
const array& in,
|
|
||||||
array& out,
|
|
||||||
Reduce::ReduceType reduce_type,
|
|
||||||
const std::vector<int>& axes,
|
|
||||||
const ReductionPlan& plan);
|
|
||||||
|
|
||||||
void row_reduce(
|
void row_reduce(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
const array& in,
|
const array& in,
|
||||||
@ -77,4 +69,10 @@ void col_reduce(
|
|||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
const ReductionPlan& plan);
|
const ReductionPlan& plan);
|
||||||
|
|
||||||
|
void init_reduce(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
Reduce::ReduceType reduce_type);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -234,7 +234,7 @@ void row_reduce_simple(
|
|||||||
using U = cu::ReduceResult<OP, T>::type;
|
using U = cu::ReduceResult<OP, T>::type;
|
||||||
|
|
||||||
// Calculate the grid and block dims
|
// Calculate the grid and block dims
|
||||||
size_t reductions = plan.shape.back() / N_READS;
|
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
|
||||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
int threads = std::min(1024UL, reductions);
|
int threads = std::min(1024UL, reductions);
|
||||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||||
@ -284,7 +284,7 @@ void row_reduce_looped(
|
|||||||
// Calculate the grid and block dims
|
// Calculate the grid and block dims
|
||||||
args.convert_shapes_to_contiguous(x, axes);
|
args.convert_shapes_to_contiguous(x, axes);
|
||||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
size_t reductions = args.row_size / N_READS;
|
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
||||||
int threads = std::min(1024UL, reductions);
|
int threads = std::min(1024UL, reductions);
|
||||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||||
dim3 block(threads, 1, 1);
|
dim3 block(threads, 1, 1);
|
||||||
|
Loading…
Reference in New Issue
Block a user