diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu
index a8f2c6b77..2095bdb43 100644
--- a/mlx/backend/cuda/scaled_dot_product_attention.cu
+++ b/mlx/backend/cuda/scaled_dot_product_attention.cu
@@ -46,6 +46,7 @@ __global__ void kernel_sdpav_1pass(
     const T* K,
     const T* V,
     T* O,
+    const T* sinks,
     __grid_constant__ const AttnParams params) {
   constexpr int BN = 32;
   constexpr int BD = 32;
@@ -65,7 +66,7 @@ __global__ void kernel_sdpav_1pass(
   __shared__ U max_scores[BN];
   __shared__ U sum_exp_scores[BN];
 
-  const U scale_log2 = params.scale * 1.44269504089f;
+  const U scale_log2 = params.scale * M_LOG2E;
 
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<32>(block);
@@ -110,6 +111,10 @@ __global__ void kernel_sdpav_1pass(
 
   U max_score = -INFINITY;
   U sum_exp_score = 0.f;
+  if (sinks && warp_idx == 0) {
+    max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
+    sum_exp_score = 1.f;
+  }
 
   // For each key
   for (int i = kv_seq_idx; i < params.kL; i += BN) {
@@ -137,8 +142,9 @@ __global__ void kernel_sdpav_1pass(
 
       // Update the accumulators
       U new_max = max(max_score, score);
-      U factor = exp2f(max_score - new_max);
-      U exp_score = exp2f(score - new_max);
+      bool is_neg_inf = new_max == -INFINITY;
+      U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
+      U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
 
       max_score = new_max;
       sum_exp_score = sum_exp_score * factor + exp_score;
@@ -193,6 +199,7 @@ __global__ void kernel_sdpav_2pass_1(
     const T* Q,
     const T* K,
     const T* V,
+    const T* sinks,
     float* partials,
     float* sums,
     float* maxs,
@@ -268,8 +275,12 @@ __global__ void kernel_sdpav_2pass_1(
     o[i] = 0.f;
   }
 
-  U max_score = -1e9;
+  U max_score = -INFINITY;
   U sum_exp_score = 0.f;
+  if (sinks && warp_idx == 0 && block_idx == 0) {
+    max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
+    sum_exp_score = 1.f;
+  }
 
   // For each key
   for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
@@ -297,8 +308,9 @@ __global__ void kernel_sdpav_2pass_1(
 
       // Update the accumulators
       U new_max = max(max_score, score);
-      U factor = exp2f(max_score - new_max);
-      U exp_score = exp2f(score - new_max);
+      bool is_neg_inf = new_max == -INFINITY;
+      U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
+      U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
 
       max_score = new_max;
       sum_exp_score = sum_exp_score * factor + exp_score;
@@ -468,6 +480,9 @@ void sdpa_vector_1pass_fallback(
   encoder.set_input_array(q);
   encoder.set_input_array(k);
   encoder.set_input_array(v);
+  if (sinks) {
+    encoder.set_input_array(*sinks);
+  }
   encoder.set_output_array(o);
 
   cu::AttnParams params{
@@ -505,6 +520,7 @@ void sdpa_vector_1pass_fallback(
           k.data<T>(),
           v.data<T>(),
           o.data<T>(),
+          sinks ? (*sinks).data<T>() : nullptr,
           params);
     });
   });
@@ -572,6 +588,10 @@ void sdpa_vector_2pass_fallback(
   encoder.set_input_array(q);
   encoder.set_input_array(k);
   encoder.set_input_array(v);
+  if (sinks) {
+    encoder.set_input_array(*sinks);
+  }
+
   encoder.set_output_array(intermediate);
   encoder.set_output_array(sums);
   encoder.set_output_array(maxs);
@@ -587,6 +607,7 @@ void sdpa_vector_2pass_fallback(
           q.data<T>(),
           k.data<T>(),
           v.data<T>(),
+          sinks ? (*sinks).data<T>() : nullptr,
          intermediate.data<float>(),
           sums.data<float>(),
           maxs.data<float>(),
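
Note on the patch: the sink is folded into the online softmax rather than materialized as an extra score column. Seeding the accumulators with max_score = log2(e) * sink and sum_exp_score = 1 is algebraically the same as appending one extra logit that has no value vector attached, and the new -INFINITY guard keeps exp2f(-inf - (-inf)) from producing NaN when every key so far is masked and no sink is present. The host-side sketch below only illustrates that accumulation; the function name and scalar loop are invented for exposition and are not MLX API.

// Illustrative host-side sketch (not MLX code): streaming base-2 softmax
// denominator over one row of attention scores, optionally seeded with a sink,
// mirroring the accumulator updates in the kernels above.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

float softmax_denominator(const std::vector<float>& scores, const float* sink) {
  const float log2e = 1.44269504089f; // the kernels now use M_LOG2E instead
  float max_score = -INFINITY;
  float sum_exp = 0.f;
  if (sink) {
    // Equivalent to appending one extra logit equal to *sink:
    // exp2(log2e * sink - max_score) == 1 at this reference point.
    max_score = log2e * (*sink);
    sum_exp = 1.f;
  }
  for (float s : scores) {
    float score = log2e * s; // a fully masked key would arrive as -INFINITY
    float new_max = std::max(max_score, score);
    // Guard the all-masked, no-sink case so -inf - (-inf) never reaches exp2f.
    bool is_neg_inf = new_max == -INFINITY;
    float factor = is_neg_inf ? 1.f : std::exp2f(max_score - new_max);
    float exp_score = is_neg_inf ? 0.f : std::exp2f(score - new_max);
    max_score = new_max;
    sum_exp = sum_exp * factor + exp_score;
  }
  // The kernels keep max_score and sum_exp separate (maxs/sums) so partials
  // can be merged later; here we simply combine them into the denominator.
  return std::exp2f(max_score) * sum_exp;
}

int main() {
  std::vector<float> scores{0.5f, -2.0f, 1.25f};
  float sink = 0.75f;
  std::printf("with sink:    %f\n", softmax_denominator(scores, &sink));
  std::printf("without sink: %f\n", softmax_denominator(scores, nullptr));
}

Seeding only warp 0 (and, in the 2-pass kernel, only block 0) follows from the later reductions: every other warp/block still starts from -INFINITY and 0, so the sink term is counted exactly once when the per-warp and per-block partials are merged.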