Compare commits

...

2 Commits

Author SHA1 Message Date
Cheng
79071bfba4 Fix out-of-bounds default value in logsumexp/softmax (#2213) 2025-05-21 07:25:16 -07:00
Cheng
7774b87cbd Remove redundant simd_sum in logsumexp (#2210) 2025-05-21 07:25:03 -07:00
3 changed files with 9 additions and 9 deletions

View File

@@ -103,8 +103,8 @@ template <typename T, typename AccT = float, int N_READS = 4>
}
} else {
for (int i = 0; i < N_READS; i++) {
- vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
- : Limits<AccT>::finite_min;
+ vals[i] =
+ (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
}
}
prevmax = maxval;
@@ -134,10 +134,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = simd_sum(local_normalizer[simd_lane_id]);
- if (simd_group_id == 0) {
- normalizer = simd_sum(local_normalizer[simd_lane_id]);
- if (simd_lane_id == 0) {
+ if (lid == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
}
}

View File

@@ -128,8 +128,8 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
}
} else {
for (int i = 0; i < N_READS; i++) {
- vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
- : Limits<AccT>::finite_min;
+ vals[i] =
+ (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
}
}
prevmax = maxval;

View File

@@ -1036,6 +1036,9 @@ TEST_CASE("test reduction ops") {
x = array({-inf, -inf});
CHECK_EQ(logsumexp(x).item<float>(), -inf);
+ x = repeat(array(-inf), 5000);
+ CHECK_EQ(logsumexp(x).item<float>(), -inf);
x = array({0.0f, -inf});
CHECK_EQ(logsumexp(x).item<float>(), 0.0f);