mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
2 Commits
ibv-backen
...
e42e06046e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e42e06046e | ||
|
|
17432e7885 |
@@ -181,6 +181,47 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op, int N_READS = 4>
|
||||||
|
__global__ void col_reduce_small(
|
||||||
|
const T* in,
|
||||||
|
U* out,
|
||||||
|
const __grid_constant__ ColReduceArgs args,
|
||||||
|
size_t total) {
|
||||||
|
Op op;
|
||||||
|
auto grid = cg::this_grid();
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
|
||||||
|
const auto idx = grid.thread_rank() * N_READS;
|
||||||
|
const auto before_axis = idx / args.reduction_stride;
|
||||||
|
const auto after_axis = idx % args.reduction_stride;
|
||||||
|
const auto offset =
|
||||||
|
before_axis * args.reduction_stride * args.reduction_size + after_axis;
|
||||||
|
|
||||||
|
if (idx >= total) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
in += offset;
|
||||||
|
out += idx;
|
||||||
|
|
||||||
|
AlignedVector<U, N_READS> accumulator;
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
accumulator[i] = ReduceInit<Op, T>::value();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < args.reduction_size; i++) {
|
||||||
|
auto values = load_vector<N_READS>(in, 0);
|
||||||
|
|
||||||
|
for (int j = 0; j < N_READS; j++) {
|
||||||
|
accumulator[j] = op(accumulator[j], cast_to<U>(values[j]));
|
||||||
|
}
|
||||||
|
|
||||||
|
in += args.reduction_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector(out, 0, accumulator);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace cu
|
} // namespace cu
|
||||||
|
|
||||||
inline auto output_grid_for_col_reduce(
|
inline auto output_grid_for_col_reduce(
|
||||||
@@ -236,6 +277,43 @@ void col_reduce_looped(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void col_reduce_small(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
Reduce::ReduceType reduce_type,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const ReductionPlan& plan,
|
||||||
|
cu::ColReduceArgs args) {
|
||||||
|
// Allocate data for the output using in's layout to access them as
|
||||||
|
// contiguously as possible.
|
||||||
|
allocate_same_layout(out, in, axes);
|
||||||
|
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||||
|
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||||
|
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||||
|
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
using U = typename cu::ReduceResult<OP, T>::type;
|
||||||
|
|
||||||
|
constexpr int N_READS = 16 / sizeof(T);
|
||||||
|
auto tmp_grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
|
auto [grid, block] = get_grid_and_block(tmp_grid.x, tmp_grid.y, 1);
|
||||||
|
auto kernel = cu::col_reduce_small<T, U, OP, N_READS>;
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
grid,
|
||||||
|
block,
|
||||||
|
0,
|
||||||
|
in.data<T>(),
|
||||||
|
out.data<U>(),
|
||||||
|
args,
|
||||||
|
out.size());
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void col_reduce(
|
void col_reduce(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
const array& in,
|
const array& in,
|
||||||
@@ -258,8 +336,16 @@ void col_reduce(
|
|||||||
// Make the args struct to help route to the best kernel
|
// Make the args struct to help route to the best kernel
|
||||||
cu::ColReduceArgs args(in, plan, axes);
|
cu::ColReduceArgs args(in, plan, axes);
|
||||||
|
|
||||||
|
// Small col reduce with a single or contiguous reduction axis
|
||||||
|
if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
|
||||||
|
args.reduction_stride % (16 / in.itemsize()) == 0) {
|
||||||
|
col_reduce_small(
|
||||||
|
encoder, in, out, reduce_type, axes, plan, std::move(args));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Fallback col reduce
|
// Fallback col reduce
|
||||||
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
|
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
Reference in New Issue
Block a user