Compare commits

..

2 Commits

Author SHA1 Message Date
Cheng
79071bfba4 Fix out-of-bounds default value in logsumexp/softmax (#2213) 2025-05-21 07:25:16 -07:00
Cheng
7774b87cbd Remove redundant simd_sum in logsumexp (#2210) 2025-05-21 07:25:03 -07:00
3 changed files with 9 additions and 9 deletions

View File

@@ -103,8 +103,8 @@ template <typename T, typename AccT = float, int N_READS = 4>
} }
} else { } else {
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) vals[i] =
: Limits<AccT>::finite_min; (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
} }
} }
prevmax = maxval; prevmax = maxval;
@@ -134,10 +134,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = simd_sum(local_normalizer[simd_lane_id]); normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_group_id == 0) { if (lid == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
} }
}
} }

View File

@@ -128,8 +128,8 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
} }
} else { } else {
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) vals[i] =
: Limits<AccT>::finite_min; (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
} }
} }
prevmax = maxval; prevmax = maxval;

View File

@@ -1036,6 +1036,9 @@ TEST_CASE("test reduction ops") {
x = array({-inf, -inf}); x = array({-inf, -inf});
CHECK_EQ(logsumexp(x).item<float>(), -inf); CHECK_EQ(logsumexp(x).item<float>(), -inf);
x = repeat(array(-inf), 5000);
CHECK_EQ(logsumexp(x).item<float>(), -inf);
x = array({0.0f, -inf}); x = array({0.0f, -inf});
CHECK_EQ(logsumexp(x).item<float>(), 0.0f); CHECK_EQ(logsumexp(x).item<float>(), 0.0f);