mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Cuda faster softmax (#2435)
* faster softmax and logsumexp * faster softmax and logsumexp * format
This commit is contained in:
@@ -43,20 +43,19 @@ __global__ void logsumexp(const T* in, T* out, int axis_size) {
|
||||
AccT maxval = Limits<AccT>::finite_min();
|
||||
AccT normalizer = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||
AccT vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r * BLOCK_DIM + block.thread_rank(),
|
||||
make_cast_iterator<AccT>(in),
|
||||
vals,
|
||||
axis_size,
|
||||
Limits<AccT>::min());
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
auto vals = load_vector<N_READS>(in, index, axis_size, Limits<T>::min());
|
||||
prevmax = maxval;
|
||||
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
maxval = max_op(maxval, static_cast<AccT>(vals[i]));
|
||||
}
|
||||
// Online normalizer calculation for softmax:
|
||||
// https://github.com/NVIDIA/online-softmax
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer = normalizer + softmax_exp(vals[i] - maxval);
|
||||
normalizer =
|
||||
normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,9 +142,9 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
|
||||
constexpr int N_READS = 4;
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr int N_READS = 16 / sizeof(DataType);
|
||||
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
|
||||
Reference in New Issue
Block a user