Fix out-of-bounds default value in logsumexp/softmax (#2213)

This commit is contained in:
Cheng 2025-05-21 23:25:16 +09:00 committed by GitHub
parent 7774b87cbd
commit 79071bfba4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 7 additions and 4 deletions

View File

@ -103,8 +103,8 @@ template <typename T, typename AccT = float, int N_READS = 4>
} }
} else { } else {
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) vals[i] =
: Limits<AccT>::finite_min; (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
} }
} }
prevmax = maxval; prevmax = maxval;

View File

@ -128,8 +128,8 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
} }
} else { } else {
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) vals[i] =
: Limits<AccT>::finite_min; (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
} }
} }
prevmax = maxval; prevmax = maxval;

View File

@ -1036,6 +1036,9 @@ TEST_CASE("test reduction ops") {
x = array({-inf, -inf}); x = array({-inf, -inf});
CHECK_EQ(logsumexp(x).item<float>(), -inf); CHECK_EQ(logsumexp(x).item<float>(), -inf);
x = repeat(array(-inf), 5000);
CHECK_EQ(logsumexp(x).item<float>(), -inf);
x = array({0.0f, -inf}); x = array({0.0f, -inf});
CHECK_EQ(logsumexp(x).item<float>(), 0.0f); CHECK_EQ(logsumexp(x).item<float>(), 0.0f);