Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
faster softmax and logsumexp
This commit is contained in:
@@ -54,7 +54,7 @@ __global__ void logsumexp(const T* in, T* out, int axis_size) {
   // https://github.com/NVIDIA/online-softmax
   normalizer = normalizer * softmax_exp(prevmax - maxval);
   for (int i = 0; i < N_READS; i++) {
-    normalizer = normalizer + softmax_exp(vals[i] - maxval);
+    normalizer = normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
   }
 }

@@ -57,7 +57,8 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
   normalizer = normalizer * softmax_exp(prevmax - maxval);
 #pragma unroll
   for (int i = 0; i < N_READS; i++) {
-    normalizer = normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
+    normalizer =
+        normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
   }
 }

Reference in New Issue
Block a user