Mirror of https://github.com/ml-explore/mlx.git
avoid producing NaN in attention (#2608)
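In short: when every key of a query is masked out, the fused kernels previously initialized the running maximum to -INFINITY, so the streaming-softmax update computed exp2(-inf - (-inf)) and then divided by a zero denominator, and the attention output came back as NaN. The changes below initialize the running maximum to the dtype's finite minimum and guard the zero-denominator reciprocal, so fully masked queries now come back finite. A minimal repro in Python, mirroring the updated test at the bottom of the diff (assumes a current mlx build; before this change the output was all NaN):

import mlx.core as mx

q = mx.random.normal(shape=(1, 4, 8, 128))
k = mx.random.normal(shape=(1, 4, 8, 128))
v = mx.random.normal(shape=(1, 4, 8, 128))

# A scalar boolean False mask disallows every key for every query.
mask = mx.array(False)
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1)

# After this commit the fully masked output is finite rather than NaN.
assert not mx.any(mx.isnan(out))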
@@ -108,7 +108,7 @@ __global__ void kernel_sdpav_1pass(
     o[i] = 0.f;
   }
 
-  U max_score = -INFINITY;
+  U max_score = Limits<U>::finite_min();
   U sum_exp_score = 0.f;
   if (sinks && warp_idx == 0) {
     max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
@@ -141,9 +141,8 @@ __global__ void kernel_sdpav_1pass(
 
     // Update the accumulators
     U new_max = max(max_score, score);
-    bool is_neg_inf = new_max == -INFINITY;
-    U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
-    U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
+    U factor = exp2f(max_score - new_max);
+    U exp_score = exp2f(score - new_max);
 
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
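For context, the accumulator update above is the usual streaming-softmax recurrence: keep a running maximum m and a running denominator s, and rescale s whenever m grows. If m starts at -INFINITY and every score in the row is also -inf (fully masked), the rescale factor is exp2(-inf - (-inf)), which is NaN; starting from the finite minimum keeps every subtraction finite, so the removed is_neg_inf special case is no longer needed. An illustrative NumPy sketch of the recurrence (not the kernel code):

import numpy as np

def online_softmax_denominator(scores, init):
    # Streaming max / sum update, as in the accumulator loop above.
    m = np.float32(init)
    s = np.float32(0)
    for x in scores:
        x = np.float32(x)
        new_m = max(m, x)
        factor = np.exp2(m - new_m)  # rescale the old partial sum
        s = s * factor + np.exp2(x - new_m)
        m = new_m
    return m, s

masked = [-np.inf] * 4  # a fully masked row: every score is -inf
print(online_softmax_denominator(masked, init=-np.inf))                   # sum is NaN
print(online_softmax_denominator(masked, init=np.finfo(np.float32).min))  # sum stays 0.0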
@@ -172,7 +171,7 @@ __global__ void kernel_sdpav_1pass(
   U factor = exp2f(max_score - new_max);
   sum_exp_score =
       cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
-  sum_exp_score = __frcp_rn(sum_exp_score);
+  sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
 
   // Now we need to aggregate all the outputs
   PRAGMA_LOOP_UNROLL
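The reciprocal guard matters for the same fully-masked case: nothing is ever accumulated, so sum_exp_score stays 0, __frcp_rn(0) is +inf, and multiplying the (zero) output accumulator by +inf yields NaN. Returning 0 for the reciprocal instead leaves the masked rows at zero. Roughly, in scalar form (illustrative only):

import numpy as np

s = np.float32(0.0)    # softmax denominator of a fully masked query
acc = np.float32(0.0)  # exp-weighted output accumulator, also zero

# Old behaviour: unconditional reciprocal -> inf, and 0 * inf is NaN.
print(acc * np.reciprocal(s))                                 # nan
# New behaviour: guard the zero denominator, the row stays 0.
print(acc * (np.float32(0) if s == 0 else np.reciprocal(s)))  # 0.0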
@@ -274,7 +273,7 @@ __global__ void kernel_sdpav_2pass_1(
     o[i] = 0.f;
   }
 
-  U max_score = -INFINITY;
+  U max_score = Limits<U>::finite_min();
   U sum_exp_score = 0.f;
   if (sinks && warp_idx == 0 && block_idx == 0) {
     max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
@@ -307,9 +306,8 @@ __global__ void kernel_sdpav_2pass_1(
 
     // Update the accumulators
     U new_max = max(max_score, score);
-    bool is_neg_inf = new_max == -INFINITY;
-    U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
-    U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
+    U factor = exp2f(max_score - new_max);
+    U exp_score = exp2f(score - new_max);
 
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
@@ -421,7 +419,7 @@ __global__ void kernel_sdpav_2pass_2(
   U new_max = cg::reduce(warp, max_score, cg::greater<U>());
   U factor = exp2f(max_score - new_max);
   U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
-  sum_exp_score = __frcp_rn(sum_exp_score);
+  sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
 
   PRAGMA_LOOP_UNROLL
   for (int i = 0; i < v_per_thread; i++) {

@@ -87,7 +87,7 @@ template <typename T, int D, int V = D>
     o[i] = 0;
   }
 
-  U max_score = -INFINITY;
+  U max_score = Limits<U>::finite_min;
   U sum_exp_score = 0;
   if (has_sinks && simd_gid == 0) {
     max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
@@ -122,9 +122,8 @@ template <typename T, int D, int V = D>
 
     // Update the accumulators
     U new_max = max(max_score, score);
-    bool is_neg_inf = new_max == -INFINITY;
-    U factor = is_neg_inf ? 1.0 : fast::exp(max_score - new_max);
-    U exp_score = is_neg_inf ? 0.0 : fast::exp(score - new_max);
+    U factor = fast::exp(max_score - new_max);
+    U exp_score = fast::exp(score - new_max);
 
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
@@ -163,7 +162,8 @@ template <typename T, int D, int V = D>
   for (int i = 0; i < v_per_thread; i++) {
     outputs[simd_lid * BD + simd_gid] = o[i];
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
+    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
+    o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
    threadgroup_barrier(mem_flags::mem_threadgroup);
   }
 
@@ -259,7 +259,7 @@ template <typename T, int D, int V = D>
     o[i] = 0;
   }
 
-  U max_score = -INFINITY;
+  U max_score = Limits<U>::finite_min;
   U sum_exp_score = 0;
   if (has_sinks && block_idx == 0 && simd_gid == 0) {
     int q_head_idx = q_batch_head_idx % num_q_heads;
@@ -289,9 +289,6 @@ template <typename T, int D, int V = D>
       score += q[i] * k[i];
     }
     score = simd_sum(score);
-    if (score < Limits<T>::finite_min) {
-      continue;
-    }
 
     if (float_mask) {
       score += fmask[0];
@@ -404,7 +401,8 @@ template <typename T, int D>
   for (int i = 0; i < elem_per_thread; i++) {
     outputs[simd_lid * BD + simd_gid] = o[i];
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
+    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
+    o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
     threadgroup_barrier(mem_flags::mem_threadgroup);
   }
 

@@ -732,10 +732,7 @@ array scaled_dot_product_attention(
   }
   if (mask.dtype() == bool_) {
     scores = where(
-        mask,
-        scores,
-        array(-std::numeric_limits<float>::infinity(), scores.dtype()),
-        s);
+        mask, scores, array(finfo(scores.dtype()).min, scores.dtype()), s);
   } else {
     scores = add(scores, mask, s);
   }

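In the non-fused fallback above (and the Python reference below), boolean masks now fill disallowed positions with the dtype's finite minimum instead of -infinity. For any row with at least one allowed key the result is numerically the same (the finite minimum still underflows to zero probability once the row maximum is subtracted), but a fully masked row degrades to a finite distribution instead of NaN. A small sketch of the difference, assuming mx.finfo as used in the diff:

import mlx.core as mx

scores = mx.zeros((1, 4))
mask = mx.array([False, False, False, False])  # fully masked row

old = mx.softmax(mx.where(mask, scores, -float("inf")), axis=-1)
new = mx.softmax(mx.where(mask, scores, mx.finfo(scores.dtype).min), axis=-1)

print(old)  # NaN: the max-subtracted scores are -inf - (-inf)
print(new)  # finite: uniform weights over the masked positions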
@@ -38,7 +38,7 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None):
         mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
 
         if mask.dtype == mx.bool_:
-            scores = mx.where(mask, scores, -np.float32(np.inf))
+            scores = mx.where(mask, scores, mx.finfo(scores.dtype).min)
         else:
             scores += mask
 
@@ -410,18 +410,15 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
 
     def test_fully_masked(self):
         Lkv = 8
-        masks = [mx.array(False), mx.array(-float("inf"))]
-        for mask in masks:
-            for D in [4, 128]:
-                for Lq in [1, 8]:
-                    q = mx.random.normal(shape=(1, 4, Lq, D))
-                    k = mx.random.normal(shape=(1, 4, Lkv, D))
-                    v = mx.random.normal(shape=(1, 4, Lkv, D))
+        mask = mx.array(False)
+        for D in [128]:
+            for Lq in [1, 8, 32]:
+                q = mx.random.normal(shape=(1, 4, Lq, D))
+                k = mx.random.normal(shape=(1, 4, Lkv, D))
+                v = mx.random.normal(shape=(1, 4, Lkv, D))
 
-                    out = mx.fast.scaled_dot_product_attention(
-                        q, k, v, mask=mask, scale=1
-                    )
-                    self.assertTrue(mx.all(mx.isnan(out)))
+                out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1)
+                self.assertFalse(mx.any(mx.isnan(out)))
 
     def test_inf_score(self):
         Lkv = 8