avoid producing NaN in attention (#2608)

This commit is contained in:
Awni Hannun
2025-09-22 13:10:43 -07:00
committed by GitHub
parent aa9d44b3d4
commit 711a645807
4 changed files with 26 additions and 36 deletions

View File

@@ -87,7 +87,7 @@ template <typename T, int D, int V = D>
o[i] = 0;
}
U max_score = -INFINITY;
U max_score = Limits<U>::finite_min;
U sum_exp_score = 0;
if (has_sinks && simd_gid == 0) {
max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
@@ -122,9 +122,8 @@ template <typename T, int D, int V = D>
// Update the accumulators
U new_max = max(max_score, score);
bool is_neg_inf = new_max == -INFINITY;
U factor = is_neg_inf ? 1.0 : fast::exp(max_score - new_max);
U exp_score = is_neg_inf ? 0.0 : fast::exp(score - new_max);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
@@ -163,7 +162,8 @@ template <typename T, int D, int V = D>
for (int i = 0; i < v_per_thread; i++) {
outputs[simd_lid * BD + simd_gid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
@@ -259,7 +259,7 @@ template <typename T, int D, int V = D>
o[i] = 0;
}
U max_score = -INFINITY;
U max_score = Limits<U>::finite_min;
U sum_exp_score = 0;
if (has_sinks && block_idx == 0 && simd_gid == 0) {
int q_head_idx = q_batch_head_idx % num_q_heads;
@@ -289,9 +289,6 @@ template <typename T, int D, int V = D>
score += q[i] * k[i];
}
score = simd_sum(score);
if (score < Limits<T>::finite_min) {
continue;
}
if (float_mask) {
score += fmask[0];
@@ -404,7 +401,8 @@ template <typename T, int D>
for (int i = 0; i < elem_per_thread; i++) {
outputs[simd_lid * BD + simd_gid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
threadgroup_barrier(mem_flags::mem_threadgroup);
}