diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu
index 1f1a7c493..a8f2c6b77 100644
--- a/mlx/backend/cuda/scaled_dot_product_attention.cu
+++ b/mlx/backend/cuda/scaled_dot_product_attention.cu
@@ -463,7 +463,8 @@ void sdpa_vector_1pass_fallback(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false) {
+    bool do_causal,
+    const std::optional<array>& sinks) {
   encoder.set_input_array(q);
   encoder.set_input_array(k);
   encoder.set_input_array(v);
@@ -489,7 +490,7 @@ void sdpa_vector_1pass_fallback(
   dim3 block_dim(1024, 1, 1);
 
   dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
-    dispatch_bool(do_causal_, [&](auto do_causal) {
+    dispatch_bool(do_causal, [&](auto do_causal) {
       dispatch_headdim(params.D, [&](auto headdim) {
         using DataType = cuda_type_t;
@@ -518,7 +519,8 @@ void sdpa_vector_2pass_fallback(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false) {
+    bool do_causal,
+    const std::optional<array>& sinks) {
   cu::AttnParams params{
       /* int B = */ q.shape(0),
       /* int H = */ q.shape(1),
@@ -559,7 +561,7 @@ void sdpa_vector_2pass_fallback(
   encoder.add_temporary(maxs);
 
   dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
-    dispatch_bool(do_causal_, [&](auto do_causal) {
+    dispatch_bool(do_causal, [&](auto do_causal) {
       dispatch_headdim(params.D, [&](auto headdim) {
         using DataType = cuda_type_t;
@@ -627,15 +629,16 @@ void sdpa_vector_fallback(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false) {
+    bool do_causal,
+    const std::optional<array>& sinks) {
   int kL = k.shape(2);
   if (kL > 1024) {
     return sdpa_vector_2pass_fallback(
-        s, encoder, q, k, v, scale, o, do_causal_);
+        s, encoder, q, k, v, scale, o, do_causal, sinks);
   } else {
     return sdpa_vector_1pass_fallback(
-        s, encoder, q, k, v, scale, o, do_causal_);
+        s, encoder, q, k, v, scale, o, do_causal, sinks);
   }
 }
@@ -703,6 +706,16 @@ void ScaledDotProductAttention::eval_gpu(
     }
   };
 
+  // Checks that the headdim dimension has stride 1.
+  auto is_matrix_contiguous = [](const array& arr) {
+    return arr.strides(-1) == 1;
+  };
+
+  std::optional<array> sinks = std::nullopt;
+  if (has_sinks_) {
+    sinks = copy_unless(is_matrix_contiguous, inputs.back());
+  }
+
   // We are in vector mode ie single query
   if (q_pre.shape(2) < 4) {
     auto q_copy_unless = [](const array& arr) {
@@ -766,7 +779,8 @@ void ScaledDotProductAttention::eval_gpu(
       encoder.add_temporary(cp);
     }
 
-    return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
+    return sdpa_vector_fallback(
+        s, encoder, q, k, v, scale_, o, do_causal_, sinks);
   }
 
   // Full attention mode should never reach here
diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h
index 159d268a6..b7ded1a69 100644
--- a/mlx/backend/metal/kernels/sdpa_vector.h
+++ b/mlx/backend/metal/kernels/sdpa_vector.h
@@ -116,17 +116,15 @@ template
       score += q[j] * k[j];
     }
     score = simd_sum(score);
-    if (score < Limits<U>::finite_min) {
-      continue;
-    }
     if (float_mask) {
       score += static_cast<U>(fmask[0]);
    }
 
     // Update the accumulators
     U new_max = max(max_score, score);
-    U factor = fast::exp(max_score - new_max);
-    U exp_score = fast::exp(score - new_max);
+    bool is_neg_inf = new_max == -INFINITY;
+    U factor = is_neg_inf ? 1.0 : fast::exp(max_score - new_max);
+    U exp_score = is_neg_inf ? 0.0 : fast::exp(score - new_max);
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
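
A minimal host-side C++ sketch (not the Metal kernel itself; the loop and score values are made up for illustration) of the guarded online-softmax update that the second hunk introduces: when the running maximum is still -inf because every key seen so far was masked out, exp(max_score - new_max) would evaluate exp(-inf - (-inf)) = NaN, so the guard pins factor to 1 and exp_score to 0 and leaves the accumulators untouched.

// Standalone illustration of the guarded online-softmax accumulator update.
// Plain C++ for readability; the real code runs per-simdgroup in Metal.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  const float neg_inf = -std::numeric_limits<float>::infinity();
  // Two fully masked positions followed by two real scores (made-up values).
  std::vector<float> scores = {neg_inf, neg_inf, 2.0f, 0.5f};

  float max_score = neg_inf;
  float sum_exp_score = 0.0f;

  for (float score : scores) {
    float new_max = std::max(max_score, score);
    // Guard: while everything seen so far is -inf, exp(max_score - new_max)
    // would be exp(-inf - (-inf)) = NaN. Keeping factor = 1 and exp_score = 0
    // leaves the accumulators untouched instead of poisoning them.
    bool is_neg_inf = new_max == neg_inf;
    float factor = is_neg_inf ? 1.0f : std::exp(max_score - new_max);
    float exp_score = is_neg_inf ? 0.0f : std::exp(score - new_max);
    max_score = new_max;
    sum_exp_score = sum_exp_score * factor + exp_score;
  }

  // Expected: max_score = 2, sum_exp_score = exp(0) + exp(0.5 - 2) ~= 1.2231.
  std::printf("max_score = %f, sum_exp_score = %f\n", max_score, sum_exp_score);
  return 0;
}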