consistently handle all -inf in softmax (#1470)

This commit is contained in:
Awni Hannun
2024-10-08 09:54:02 -07:00
committed by GitHub
parent 3274c6a087
commit 1fa0d20a30
3 changed files with 13 additions and 6 deletions

View File

@@ -32,12 +32,12 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
: Limits<AccT>::finite_min;
ld[i] =
((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
}
}
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<AccT>::finite_min;
local_max[simd_lane_id] = Limits<AccT>::min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);