faster softmax and logsumexp

Awni Hannun
2025-07-29 14:34:00 -07:00
parent 483221631a
commit 5694f764fc
2 changed files with 3 additions and 2 deletions

@@ -54,7 +54,7 @@ __global__ void logsumexp(const T* in, T* out, int axis_size) {
       // https://github.com/NVIDIA/online-softmax
       normalizer = normalizer * softmax_exp(prevmax - maxval);
       for (int i = 0; i < N_READS; i++) {
-        normalizer = normalizer + softmax_exp(vals[i] - maxval);
+        normalizer = normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
       }
     }
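Both kernels accumulate the normalizer with the single-pass "online softmax" recurrence referenced in the comment above: whenever a larger running maximum is found, the accumulated sum of exponentials is rescaled by softmax_exp(prevmax - maxval) before the new terms are added. The following is a minimal host-side sketch of that recurrence, with illustrative names rather than the kernel's own; the CUDA kernels additionally vectorize reads (N_READS per thread) and reduce across the block.

// Minimal sketch of the online normalizer recurrence
// (after https://github.com/NVIDIA/online-softmax).
// Illustrative only; not the kernel code.
#include <cmath>
#include <limits>
#include <vector>

float online_logsumexp(const std::vector<float>& x) {
  float maxval = -std::numeric_limits<float>::infinity();
  float normalizer = 0.0f; // running sum of exp(x[i] - maxval)
  for (float v : x) {
    float prevmax = maxval;
    maxval = std::fmax(maxval, v);
    // Rescale the accumulated sum to the new maximum, then add the new term.
    normalizer = normalizer * std::exp(prevmax - maxval);
    normalizer = normalizer + std::exp(v - maxval);
  }
  return maxval + std::log(normalizer); // logsumexp(x)
}

For softmax, the same running maxval and normalizer are reused at the end to write exp(x[i] - maxval) / normalizer for each element.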

@@ -57,7 +57,8 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
       normalizer = normalizer * softmax_exp(prevmax - maxval);
 #pragma unroll
       for (int i = 0; i < N_READS; i++) {
-        normalizer = normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
+        normalizer =
+            normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
       }
     }
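The static_cast suggests the per-thread values vals are held in the input type T while maxval and normalizer are kept in the accumulation type AccT, so converting vals[i] before the subtraction keeps both the x - max and the exponential in the accumulation type. Below is a hypothetical sketch of that pattern; accumulate_normalizer is an illustrative name, and std::exp stands in for the kernels' softmax_exp helper.

// Hypothetical sketch of accumulating the normalizer in a wider type than
// the stored elements. Not the kernel code; names are illustrative.
#include <cmath>

template <typename T, typename AccT, int N_READS>
AccT accumulate_normalizer(const T (&vals)[N_READS], AccT maxval, AccT normalizer) {
  for (int i = 0; i < N_READS; i++) {
    // Cast first so the subtraction and the exponential evaluate in AccT.
    normalizer = normalizer + std::exp(static_cast<AccT>(vals[i]) - maxval);
  }
  return normalizer;
}

For example, instantiating with T = float and AccT = double on the host mirrors the kernels' split between a narrow storage type and a wider accumulation type.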