mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-21 18:28:11 +08:00
avoid producing NaN in attention (#2608)
This commit is contained in:
@@ -87,7 +87,7 @@ template <typename T, int D, int V = D>
|
||||
o[i] = 0;
|
||||
}
|
||||
|
||||
U max_score = -INFINITY;
|
||||
U max_score = Limits<U>::finite_min;
|
||||
U sum_exp_score = 0;
|
||||
if (has_sinks && simd_gid == 0) {
|
||||
max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
|
||||
@@ -122,9 +122,8 @@ template <typename T, int D, int V = D>
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
bool is_neg_inf = new_max == -INFINITY;
|
||||
U factor = is_neg_inf ? 1.0 : fast::exp(max_score - new_max);
|
||||
U exp_score = is_neg_inf ? 0.0 : fast::exp(score - new_max);
|
||||
U factor = fast::exp(max_score - new_max);
|
||||
U exp_score = fast::exp(score - new_max);
|
||||
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
@@ -163,7 +162,8 @@ template <typename T, int D, int V = D>
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
outputs[simd_lid * BD + simd_gid] = o[i];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
|
||||
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
|
||||
o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
@@ -259,7 +259,7 @@ template <typename T, int D, int V = D>
|
||||
o[i] = 0;
|
||||
}
|
||||
|
||||
U max_score = -INFINITY;
|
||||
U max_score = Limits<U>::finite_min;
|
||||
U sum_exp_score = 0;
|
||||
if (has_sinks && block_idx == 0 && simd_gid == 0) {
|
||||
int q_head_idx = q_batch_head_idx % num_q_heads;
|
||||
@@ -289,9 +289,6 @@ template <typename T, int D, int V = D>
|
||||
score += q[i] * k[i];
|
||||
}
|
||||
score = simd_sum(score);
|
||||
if (score < Limits<T>::finite_min) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (float_mask) {
|
||||
score += fmask[0];
|
||||
@@ -404,7 +401,8 @@ template <typename T, int D>
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
outputs[simd_lid * BD + simd_gid] = o[i];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
|
||||
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
|
||||
o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user