diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu
index 7d5437ef4..825da5cd3 100644
--- a/mlx/backend/cuda/scaled_dot_product_attention.cu
+++ b/mlx/backend/cuda/scaled_dot_product_attention.cu
@@ -108,7 +108,7 @@ __global__ void kernel_sdpav_1pass(
     o[i] = 0.f;
   }
 
-  U max_score = -INFINITY;
+  U max_score = Limits<U>::finite_min();
   U sum_exp_score = 0.f;
   if (sinks && warp_idx == 0) {
     max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
@@ -141,9 +141,8 @@ __global__ void kernel_sdpav_1pass(
 
     // Update the accumulators
     U new_max = max(max_score, score);
-    bool is_neg_inf = new_max == -INFINITY;
-    U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
-    U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
+    U factor = exp2f(max_score - new_max);
+    U exp_score = exp2f(score - new_max);
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
 
@@ -172,7 +171,7 @@ __global__ void kernel_sdpav_1pass(
   U factor = exp2f(max_score - new_max);
   sum_exp_score =
       cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
-  sum_exp_score = __frcp_rn(sum_exp_score);
+  sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
 
   // Now we need to aggregate all the outputs
   PRAGMA_LOOP_UNROLL
@@ -274,7 +273,7 @@ __global__ void kernel_sdpav_2pass_1(
     o[i] = 0.f;
   }
 
-  U max_score = -INFINITY;
+  U max_score = Limits<U>::finite_min();
   U sum_exp_score = 0.f;
   if (sinks && warp_idx == 0 && block_idx == 0) {
     max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
@@ -307,9 +306,8 @@ __global__ void kernel_sdpav_2pass_1(
 
     // Update the accumulators
     U new_max = max(max_score, score);
-    bool is_neg_inf = new_max == -INFINITY;
-    U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
-    U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
+    U factor = exp2f(max_score - new_max);
+    U exp_score = exp2f(score - new_max);
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
 
@@ -421,7 +419,7 @@ __global__ void kernel_sdpav_2pass_2(
   U new_max = cg::reduce(warp, max_score, cg::greater<U>());
   U factor = exp2f(max_score - new_max);
   U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
-  sum_exp_score = __frcp_rn(sum_exp_score);
+  sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
 
   PRAGMA_LOOP_UNROLL
   for (int i = 0; i < v_per_thread; i++) {
diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h
index b7ded1a69..96d22d8e4 100644
--- a/mlx/backend/metal/kernels/sdpa_vector.h
+++ b/mlx/backend/metal/kernels/sdpa_vector.h
@@ -87,7 +87,7 @@ template
     o[i] = 0;
   }
 
-  U max_score = -INFINITY;
+  U max_score = Limits<U>::finite_min;
   U sum_exp_score = 0;
   if (has_sinks && simd_gid == 0) {
     max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
@@ -122,9 +122,8 @@ template
 
     // Update the accumulators
    U new_max = max(max_score, score);
-    bool is_neg_inf = new_max == -INFINITY;
-    U factor = is_neg_inf ? 1.0 : fast::exp(max_score - new_max);
-    U exp_score = is_neg_inf ? 0.0 : fast::exp(score - new_max);
+    U factor = fast::exp(max_score - new_max);
+    U exp_score = fast::exp(score - new_max);
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
 
@@ -163,7 +162,8 @@ template
   for (int i = 0; i < v_per_thread; i++) {
     outputs[simd_lid * BD + simd_gid] = o[i];
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
+    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
+    o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
     threadgroup_barrier(mem_flags::mem_threadgroup);
   }
 
@@ -259,7 +259,7 @@ template
     o[i] = 0;
   }
 
-  U max_score = -INFINITY;
+  U max_score = Limits<U>::finite_min;
   U sum_exp_score = 0;
   if (has_sinks && block_idx == 0 && simd_gid == 0) {
     int q_head_idx = q_batch_head_idx % num_q_heads;
@@ -289,9 +289,6 @@ template
      score += q[i] * k[i];
    }
    score = simd_sum(score);
-   if (score < Limits<U>::finite_min) {
-     continue;
-   }
 
    if (float_mask) {
      score += fmask[0];
@@ -404,7 +401,8 @@ template
   for (int i = 0; i < elem_per_thread; i++) {
     outputs[simd_lid * BD + simd_gid] = o[i];
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
+    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
+    o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
     threadgroup_barrier(mem_flags::mem_threadgroup);
   }
 
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 324316507..e5bd3ad72 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -732,10 +732,7 @@ array scaled_dot_product_attention(
    }
    if (mask.dtype() == bool_) {
      scores = where(
-         mask,
-         scores,
-         array(-std::numeric_limits<float>::infinity(), scores.dtype()),
-         s);
+         mask, scores, array(finfo(scores.dtype()).min, scores.dtype()), s);
    } else {
      scores = add(scores, mask, s);
    }
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index 52ecc9be0..19af012c6 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -38,7 +38,7 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None):
            mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
 
        if mask.dtype == mx.bool_:
-            scores = mx.where(mask, scores, -np.float32(np.inf))
+            scores = mx.where(mask, scores, mx.finfo(scores.dtype).min)
        else:
            scores += mask
@@ -410,18 +410,15 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
 
     def test_fully_masked(self):
         Lkv = 8
-        masks = [mx.array(False), mx.array(-float("inf"))]
-        for mask in masks:
-            for D in [4, 128]:
-                for Lq in [1, 8]:
-                    q = mx.random.normal(shape=(1, 4, Lq, D))
-                    k = mx.random.normal(shape=(1, 4, Lkv, D))
-                    v = mx.random.normal(shape=(1, 4, Lkv, D))
+        mask = mx.array(False)
+        for D in [128]:
+            for Lq in [1, 8, 32]:
+                q = mx.random.normal(shape=(1, 4, Lq, D))
+                k = mx.random.normal(shape=(1, 4, Lkv, D))
+                v = mx.random.normal(shape=(1, 4, Lkv, D))
 
-                    out = mx.fast.scaled_dot_product_attention(
-                        q, k, v, mask=mask, scale=1
-                    )
-                    self.assertTrue(mx.all(mx.isnan(out)))
+                out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1)
+                self.assertFalse(mx.any(mx.isnan(out)))
 
     def test_inf_score(self):
         Lkv = 8
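Note (reviewer sketch, not part of the patch): the kernels stream over the keys while keeping a running maximum and exponent sum. Starting the running max at the dtype's finite minimum instead of -INFINITY keeps every exp2f / fast::exp argument finite, which is what allows the is_neg_inf branches to be dropped, and guarding the reciprocal/division covers the fully masked case where no key is ever accumulated and the exp-sum stays zero. The Python below is a minimal stand-in for that update rule under those assumptions: sdpa_vector_row and FINITE_MIN are illustrative names, np.exp stands in for the base-2 exponentials in the kernels, and the `continue` mirrors how masked keys are skipped before the accumulator update.

```python
import numpy as np

FINITE_MIN = np.finfo(np.float32).min  # stand-in for Limits<U>::finite_min


def sdpa_vector_row(scores, values, keep, init_max, guard_division):
    """One query row of streaming softmax(scores) @ values with a running max."""
    max_score = init_max
    sum_exp = 0.0
    out = np.zeros_like(values[0])
    for s, v, k in zip(scores, values, keep):
        if not k:  # masked keys are skipped, as in the kernels
            continue
        new_max = max(max_score, s)
        factor = np.exp(max_score - new_max)  # rescale previous accumulators
        exp_s = np.exp(s - new_max)
        max_score = new_max
        sum_exp = sum_exp * factor + exp_s
        out = out * factor + exp_s * v
    if guard_division:
        # New behaviour: a fully masked row (sum_exp == 0) yields zeros.
        return out if sum_exp == 0 else out / sum_exp
    # Old behaviour: 0 / 0 -> NaN when every key was masked (numpy warns here).
    return out / sum_exp


rng = np.random.default_rng(0)
scores = rng.normal(size=4).astype(np.float32)
values = rng.normal(size=(4, 8)).astype(np.float32)
fully_masked = np.zeros(4, dtype=bool)

print(sdpa_vector_row(scores, values, fully_masked, -np.inf, False))    # NaNs
print(sdpa_vector_row(scores, values, fully_masked, FINITE_MIN, True))  # zeros
```

Under the old scheme (-inf start, unguarded division) the fully masked row comes out as 0/0 = NaN, which is exactly the outcome the updated test_fully_masked now asserts can no longer happen.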