avoid producing NaN in attention (#2608)

This commit is contained in:
Awni Hannun
2025-09-22 13:10:43 -07:00
committed by GitHub
parent aa9d44b3d4
commit 711a645807
4 changed files with 26 additions and 36 deletions

View File

@@ -108,7 +108,7 @@ __global__ void kernel_sdpav_1pass(
o[i] = 0.f;
}
U max_score = -INFINITY;
U max_score = Limits<U>::finite_min();
U sum_exp_score = 0.f;
if (sinks && warp_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
@@ -141,9 +141,8 @@ __global__ void kernel_sdpav_1pass(
// Update the accumulators
U new_max = max(max_score, score);
bool is_neg_inf = new_max == -INFINITY;
U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
U factor = exp2f(max_score - new_max);
U exp_score = exp2f(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
@@ -172,7 +171,7 @@ __global__ void kernel_sdpav_1pass(
U factor = exp2f(max_score - new_max);
sum_exp_score =
cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
sum_exp_score = __frcp_rn(sum_exp_score);
sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
// Now we need to aggregate all the outputs
PRAGMA_LOOP_UNROLL
@@ -274,7 +273,7 @@ __global__ void kernel_sdpav_2pass_1(
o[i] = 0.f;
}
U max_score = -INFINITY;
U max_score = Limits<U>::finite_min();
U sum_exp_score = 0.f;
if (sinks && warp_idx == 0 && block_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
@@ -307,9 +306,8 @@ __global__ void kernel_sdpav_2pass_1(
// Update the accumulators
U new_max = max(max_score, score);
bool is_neg_inf = new_max == -INFINITY;
U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
U factor = exp2f(max_score - new_max);
U exp_score = exp2f(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
@@ -421,7 +419,7 @@ __global__ void kernel_sdpav_2pass_2(
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
sum_exp_score = __frcp_rn(sum_exp_score);
sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {