Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-06 10:54:11 +08:00
[CUDA] Add a small column specialization to reduce (#2642)
parent b0cc71ae71 · commit c2c3e0b0a2
@@ -181,6 +181,47 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
   }
 }
 
+template <typename T, typename U, typename Op, int N_READS = 4>
+__global__ void col_reduce_small(
+    const T* in,
+    U* out,
+    const __grid_constant__ ColReduceArgs args,
+    size_t total) {
+  Op op;
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  const auto idx = grid.thread_rank() * N_READS;
+  const auto before_axis = idx / args.reduction_stride;
+  const auto after_axis = idx % args.reduction_stride;
+  const auto offset =
+      before_axis * args.reduction_stride * args.reduction_size + after_axis;
+
+  if (idx >= total) {
+    return;
+  }
+
+  in += offset;
+  out += idx;
+
+  AlignedVector<U, N_READS> accumulator;
+  for (int i = 0; i < N_READS; i++) {
+    accumulator[i] = ReduceInit<Op, T>::value();
+  }
+
+  for (int i = 0; i < args.reduction_size; i++) {
+    auto values = load_vector<N_READS>(in, 0);
+
+    for (int j = 0; j < N_READS; j++) {
+      accumulator[j] = op(accumulator[j], cast_to<U>(values[j]));
+    }
+
+    in += args.reduction_stride;
+  }
+
+  store_vector(out, 0, accumulator);
+}
+
 } // namespace cu
 
 inline auto output_grid_for_col_reduce(
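The new kernel gives each thread N_READS adjacent output columns and walks them down the reduction axis with 16-byte vectorized loads, so no inter-thread communication or shared memory is needed. Below is a minimal standalone sketch of the same thread-to-output mapping, specialized to float and a sum reduction over axis 0 of a row-major (rows, cols) matrix; everything here (demo_col_reduce, rows, cols) is illustrative scaffolding, not MLX code.

    // demo_col_reduce.cu -- standalone sketch, not MLX code.
    #include <cstdio>
    #include <vector>
    #include <cuda_runtime.h>

    constexpr int N_READS = 4; // 16 bytes / sizeof(float)

    __global__ void demo_col_reduce(
        const float* in, float* out, int rows, int cols, size_t total) {
      // Each thread owns N_READS adjacent output columns.
      const size_t idx =
          (size_t(blockIdx.x) * blockDim.x + threadIdx.x) * N_READS;
      if (idx >= total)
        return;

      // Same offset math as above: idx / stride selects the batch in front
      // of the reduced axis, idx % stride the column inside it. For a plain
      // 2D matrix before_axis is always 0; it matters once batch dims exist.
      const size_t before_axis = idx / cols;
      const size_t after_axis = idx % cols;
      const float* p = in + before_axis * size_t(cols) * rows + after_axis;

      float acc[N_READS] = {};
      for (int i = 0; i < rows; i++) { // walk down the reduction axis
        for (int j = 0; j < N_READS; j++)
          acc[j] += p[j]; // the MLX kernel does this as one 16-byte load_vector
        p += cols; // reduction_stride
      }
      for (int j = 0; j < N_READS; j++)
        out[idx + j] = acc[j];
    }

    int main() {
      const int rows = 32, cols = 1024; // a shape the fast path targets
      std::vector<float> h(size_t(rows) * cols, 1.0f);
      float *din, *dout;
      cudaMalloc(&din, h.size() * sizeof(float));
      cudaMalloc(&dout, cols * sizeof(float));
      cudaMemcpy(din, h.data(), h.size() * sizeof(float),
                 cudaMemcpyHostToDevice);

      const size_t total = cols; // one output per column
      const int threads = 256;
      const int blocks = int((total / N_READS + threads - 1) / threads);
      demo_col_reduce<<<blocks, threads>>>(din, dout, rows, cols, total);

      float first;
      cudaMemcpy(&first, dout, sizeof(float), cudaMemcpyDeviceToHost);
      std::printf("out[0] = %g (expected %d)\n", first, rows);
      cudaFree(din);
      cudaFree(dout);
      return 0;
    }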
@@ -206,7 +247,7 @@ void col_reduce_looped(
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
     const ReductionPlan& plan,
-    cu::ColReduceArgs args) {
+    const cu::ColReduceArgs& args) {
   // Allocate data for the output using in's layout to access them as
   // contiguously as possible.
   allocate_same_layout(out, in, axes);
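The signature change takes args by const reference, so the routing layer can hand the same struct to either kernel launcher without copying it; a __grid_constant__ kernel parameter is still passed by value, hence the static_cast<cu::ColReduceArgs>(args) at the launch sites in the next hunk, which materializes the copy exactly once. A minimal sketch of the pattern, using the hypothetical names Args, launch, and route:

    #include <cstdio>

    struct Args { int reduction_size; };

    // Stand-in for a kernel launch whose parameter must be taken by value,
    // as a __grid_constant__ parameter is. Args/launch/route are illustrative.
    void launch(Args a) { std::printf("size = %d\n", a.reduction_size); }

    void route(const Args& args) {
      // static_cast<Args>(args) creates the by-value temporary only here,
      // mirroring static_cast<cu::ColReduceArgs>(args) in the diff.
      launch(static_cast<Args>(args));
    }

    int main() { route(Args{32}); }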
@@ -230,12 +271,55 @@ void col_reduce_looped(
         auto kernel =
             cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
         encoder.add_kernel_node(
-            kernel, grid, blocks, 0, indata, out.data<U>(), args);
+            kernel,
+            grid,
+            blocks,
+            0,
+            indata,
+            out.data<U>(),
+            static_cast<cu::ColReduceArgs>(args));
       });
     });
   });
 }
 
+void col_reduce_small(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan,
+    const cu::ColReduceArgs& args) {
+  // Allocate data for the output using in's layout to access them as
+  // contiguously as possible.
+  allocate_same_layout(out, in, axes);
+
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  dispatch_all_types(in.dtype(), [&](auto type_tag) {
+    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+      using OP = MLX_GET_TYPE(reduce_type_tag);
+      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+      using U = typename cu::ReduceResult<OP, T>::type;
+
+      constexpr int N_READS = 16 / sizeof(T);
+      auto tmp_grid = get_2d_grid_dims(out.shape(), out.strides());
+      auto [grid, block] = get_grid_and_block(tmp_grid.x, tmp_grid.y, 1);
+      auto kernel = cu::col_reduce_small<T, U, OP, N_READS>;
+      encoder.add_kernel_node(
+          kernel,
+          grid,
+          block,
+          0,
+          in.data<T>(),
+          out.data<U>(),
+          static_cast<cu::ColReduceArgs>(args),
+          out.size());
+    });
+  });
+}
+
 void col_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
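The launcher fixes the vector width so every thread reads 16 bytes per reduction step regardless of dtype. A quick sketch of how N_READS = 16 / sizeof(T) plays out (n_reads is a hypothetical helper; fp16 is shown via a 2-byte stand-in type):

    #include <cstdint>
    #include <cstdio>

    // Each thread loads 16 bytes per step, so the element count per load
    // scales inversely with the element size.
    template <typename T>
    constexpr int n_reads() {
      return 16 / sizeof(T);
    }

    int main() {
      std::printf("double (8B): N_READS = %d\n", n_reads<double>());   // 2
      std::printf("float  (4B): N_READS = %d\n", n_reads<float>());    // 4
      std::printf("fp16   (2B): N_READS = %d\n", n_reads<uint16_t>()); // 8
    }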
@@ -258,6 +342,13 @@ void col_reduce(
   // Make the args struct to help route to the best kernel
   cu::ColReduceArgs args(in, plan, axes);
 
+  // Small col reduce with a single or contiguous reduction axis
+  if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
+      args.reduction_stride % (16 / in.itemsize()) == 0) {
+    col_reduce_small(encoder, in, out, reduce_type, axes, plan, args);
+    return;
+  }
+
   // Fallback col reduce
   col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
 }
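The routing test sends a reduction to col_reduce_small only when there is a single (or merged contiguous) reduction run, the reduced axis is short, and the stride is a multiple of the 16-byte vector width. A sketch of the predicate with concrete shapes (takes_fast_path is a hypothetical helper; its arguments mirror args.non_col_reductions, args.reduction_size, args.reduction_stride, and in.itemsize()):

    #include <cstdio>

    bool takes_fast_path(int non_col_reductions, int reduction_size,
                         int reduction_stride, int itemsize) {
      return non_col_reductions == 1 && reduction_size <= 32 &&
          reduction_stride % (16 / itemsize) == 0;
    }

    int main() {
      // float32 sum over axis 0 of a (32, 1024) array: size 32, stride 1024.
      std::printf("%d\n", takes_fast_path(1, 32, 1024, 4));  // 1 -> small
      // 100 rows: the reduction is too long for the specialization.
      std::printf("%d\n", takes_fast_path(1, 100, 1024, 4)); // 0 -> looped
      // stride 1022 floats breaks 16-byte alignment (1022 % 4 != 0).
      std::printf("%d\n", takes_fast_path(1, 32, 1022, 4));  // 0 -> looped
    }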