Remove redundant simd_sum in logsumexp (#2210)

This commit is contained in:
Cheng 2025-05-21 23:25:03 +09:00 committed by GitHub
parent 35c87741cf
commit 7774b87cbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -134,10 +134,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
if (lid == 0) {
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
}
}