diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu index eb1f248d5..2095bdb43 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cu +++ b/mlx/backend/cuda/scaled_dot_product_attention.cu @@ -46,6 +46,7 @@ __global__ void kernel_sdpav_1pass( const T* K, const T* V, T* O, + const T* sinks, __grid_constant__ const AttnParams params) { constexpr int BN = 32; constexpr int BD = 32; @@ -65,7 +66,7 @@ __global__ void kernel_sdpav_1pass( __shared__ U max_scores[BN]; __shared__ U sum_exp_scores[BN]; - const U scale_log2 = params.scale * 1.44269504089f; + const U scale_log2 = params.scale * M_LOG2E; auto block = cg::this_thread_block(); auto warp = cg::tiled_partition<32>(block); @@ -110,6 +111,10 @@ __global__ void kernel_sdpav_1pass( U max_score = -INFINITY; U sum_exp_score = 0.f; + if (sinks && warp_idx == 0) { + max_score = M_LOG2E * static_cast(sinks[head_idx]); + sum_exp_score = 1.f; + } // For each key for (int i = kv_seq_idx; i < params.kL; i += BN) { @@ -137,8 +142,9 @@ __global__ void kernel_sdpav_1pass( // Update the accumulators U new_max = max(max_score, score); - U factor = exp2f(max_score - new_max); - U exp_score = exp2f(score - new_max); + bool is_neg_inf = new_max == -INFINITY; + U factor = is_neg_inf ? 1 : exp2f(max_score - new_max); + U exp_score = is_neg_inf ? 0 : exp2f(score - new_max); max_score = new_max; sum_exp_score = sum_exp_score * factor + exp_score; @@ -193,6 +199,7 @@ __global__ void kernel_sdpav_2pass_1( const T* Q, const T* K, const T* V, + const T* sinks, float* partials, float* sums, float* maxs, @@ -268,8 +275,12 @@ __global__ void kernel_sdpav_2pass_1( o[i] = 0.f; } - U max_score = -1e9; + U max_score = -INFINITY; U sum_exp_score = 0.f; + if (sinks && warp_idx == 0 && block_idx == 0) { + max_score = M_LOG2E * static_cast(sinks[head_idx]); + sum_exp_score = 1.f; + } // For each key for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) { @@ -297,8 +308,9 @@ __global__ void kernel_sdpav_2pass_1( // Update the accumulators U new_max = max(max_score, score); - U factor = exp2f(max_score - new_max); - U exp_score = exp2f(score - new_max); + bool is_neg_inf = new_max == -INFINITY; + U factor = is_neg_inf ? 1 : exp2f(max_score - new_max); + U exp_score = is_neg_inf ? 0 : exp2f(score - new_max); max_score = new_max; sum_exp_score = sum_exp_score * factor + exp_score; @@ -463,10 +475,14 @@ void sdpa_vector_1pass_fallback( const array& v, const float scale, array& o, - bool do_causal_ = false) { + bool do_causal, + const std::optional& sinks) { encoder.set_input_array(q); encoder.set_input_array(k); encoder.set_input_array(v); + if (sinks) { + encoder.set_input_array(*sinks); + } encoder.set_output_array(o); cu::AttnParams params{ @@ -489,7 +505,7 @@ void sdpa_vector_1pass_fallback( dim3 block_dim(1024, 1, 1); dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) { - dispatch_bool(do_causal_, [&](auto do_causal) { + dispatch_bool(do_causal, [&](auto do_causal) { dispatch_headdim(params.D, [&](auto headdim) { using DataType = cuda_type_t; @@ -504,6 +520,7 @@ void sdpa_vector_1pass_fallback( k.data(), v.data(), o.data(), + sinks ? 
(*sinks).data() : nullptr, params); }); }); @@ -518,7 +535,8 @@ void sdpa_vector_2pass_fallback( const array& v, const float scale, array& o, - bool do_causal_ = false) { + bool do_causal, + const std::optional& sinks) { cu::AttnParams params{ /* int B = */ q.shape(0), /* int H = */ q.shape(1), @@ -559,7 +577,7 @@ void sdpa_vector_2pass_fallback( encoder.add_temporary(maxs); dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) { - dispatch_bool(do_causal_, [&](auto do_causal) { + dispatch_bool(do_causal, [&](auto do_causal) { dispatch_headdim(params.D, [&](auto headdim) { using DataType = cuda_type_t; @@ -570,6 +588,10 @@ void sdpa_vector_2pass_fallback( encoder.set_input_array(q); encoder.set_input_array(k); encoder.set_input_array(v); + if (sinks) { + encoder.set_input_array(*sinks); + } + encoder.set_output_array(intermediate); encoder.set_output_array(sums); encoder.set_output_array(maxs); @@ -585,6 +607,7 @@ void sdpa_vector_2pass_fallback( q.data(), k.data(), v.data(), + sinks ? (*sinks).data() : nullptr, intermediate.data(), sums.data(), maxs.data(), @@ -627,15 +650,16 @@ void sdpa_vector_fallback( const array& v, const float scale, array& o, - bool do_causal_ = false) { + bool do_causal, + const std::optional& sinks) { int kL = k.shape(2); if (kL > 1024) { return sdpa_vector_2pass_fallback( - s, encoder, q, k, v, scale, o, do_causal_); + s, encoder, q, k, v, scale, o, do_causal, sinks); } else { return sdpa_vector_1pass_fallback( - s, encoder, q, k, v, scale, o, do_causal_); + s, encoder, q, k, v, scale, o, do_causal, sinks); } } @@ -691,7 +715,7 @@ void ScaledDotProductAttention::eval_gpu( // Define some copy functions to ensure the layout of the inputs is as // expected. - copies.reserve(3); + copies.reserve(inputs.size()); auto copy_unless = [&copies, &s]( auto predicate, const array& arr) -> const array& { if (!predicate(arr)) { @@ -703,6 +727,16 @@ void ScaledDotProductAttention::eval_gpu( } }; + // Checks that the headdim dimension has stride 1. 
+ auto is_matrix_contiguous = [](const array& arr) { + return arr.strides(-1) == 1; + }; + + std::optional sinks = std::nullopt; + if (has_sinks_) { + sinks = copy_unless(is_matrix_contiguous, inputs.back()); + } + // We are in vector mode ie single query if (q_pre.shape(2) < 4) { auto q_copy_unless = [](const array& arr) { @@ -740,10 +774,6 @@ void ScaledDotProductAttention::eval_gpu( const auto& k = copy_unless(kv_copy_unless, k_pre); const auto& v = copy_unless(kv_copy_unless, v_pre); - for (const auto& cp : copies) { - encoder.add_temporary(cp); - } - // Donate the query if possible if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) { o.copy_shared_buffer(q); @@ -752,22 +782,26 @@ void ScaledDotProductAttention::eval_gpu( int64_t str_oH = o.shape(3); int64_t str_oL = o.shape(1) * str_oH; int64_t str_oB = o.shape(2) * str_oL; - size_t data_size = o.shape(0) * str_oB; array::Flags flags{ /* bool contiguous = */ 1, /* bool row_contiguous = */ o.shape(2) == 1, - /* bool col_contiguous = */ 0, + /* bool col_contiguous = */ o.size() == o.shape(3), }; o.set_data( allocator::malloc(o.nbytes()), - data_size, + o.size(), {str_oB, str_oH, str_oL, str_oD}, flags); } - return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_); + for (const auto& cp : copies) { + encoder.add_temporary(cp); + } + + return sdpa_vector_fallback( + s, encoder, q, k, v, scale_, o, do_causal_, sinks); } // Full attention mode should never reach here diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 8258e9c14..b7ded1a69 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -9,6 +9,7 @@ constant bool query_transposed [[function_constant(21)]]; constant bool do_causal [[function_constant(22)]]; constant bool bool_mask [[function_constant(23)]]; constant bool float_mask [[function_constant(24)]]; +constant bool has_sinks [[function_constant(25)]]; template [[kernel]] void sdpa_vector( @@ -31,6 +32,9 @@ template [[buffer(14), function_constant(has_mask)]], const constant int& mask_head_stride [[buffer(15), function_constant(has_mask)]], + const device T* sinks [[buffer(16), function_constant(has_sinks)]], + const constant int& num_q_heads + [[buffer(17), function_constant(has_sinks)]], uint3 tid [[threadgroup_position_in_grid]], uint3 tpg [[threadgroups_per_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -53,24 +57,24 @@ template threadgroup U sum_exp_scores[BN]; // Adjust positions - const int head_idx = tid.x; + const int q_batch_head_idx = tid.x; const int q_seq_idx = tid.y; - const int kv_head_idx = head_idx / gqa_factor; - const int o_offset = head_idx * tpg.y + q_seq_idx; + const int kv_head_idx = q_batch_head_idx / gqa_factor; + const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx; const int q_offset = - query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset; + query_transposed ? 
tpg.x * q_seq_idx + q_batch_head_idx : o_offset; queries += q_offset * D + simd_lid * qk_per_thread; keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride + simd_lid * qk_per_thread; values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride + simd_lid * v_per_thread; if (bool_mask) { - bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + - q_seq_idx * mask_q_seq_stride; + bmask += q_batch_head_idx * mask_head_stride + + simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; } if (float_mask) { - fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + - q_seq_idx * mask_q_seq_stride; + fmask += q_batch_head_idx * mask_head_stride + + simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; } out += o_offset * V + simd_gid * v_per_thread; @@ -85,6 +89,10 @@ template U max_score = -INFINITY; U sum_exp_score = 0; + if (has_sinks && simd_gid == 0) { + max_score = static_cast(sinks[q_batch_head_idx % num_q_heads]); + sum_exp_score = 1; + } // For each key for (int i = simd_gid; i < N; i += BN) { @@ -93,6 +101,8 @@ template use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); } else if (bool_mask) { use_key = bmask[0]; + } else if (float_mask) { + use_key = (fmask[0] >= Limits::finite_min); } if (use_key) { // Read the key @@ -107,13 +117,14 @@ template } score = simd_sum(score); if (float_mask) { - score += max(Limits::finite_min, static_cast(fmask[0])); + score += static_cast(fmask[0]); } // Update the accumulators U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); + bool is_neg_inf = new_max == -INFINITY; + U factor = is_neg_inf ? 1.0 : fast::exp(max_score - new_max); + U exp_score = is_neg_inf ? 0.0 : fast::exp(score - new_max); max_score = new_max; sum_exp_score = sum_exp_score * factor + exp_score; @@ -187,6 +198,9 @@ template [[buffer(16), function_constant(has_mask)]], const constant int& mask_head_stride [[buffer(17), function_constant(has_mask)]], + const device T* sinks [[buffer(18), function_constant(has_sinks)]], + const constant int& num_q_heads + [[buffer(19), function_constant(has_sinks)]], uint3 tid [[threadgroup_position_in_grid]], uint3 tpg [[threadgroups_per_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -211,12 +225,12 @@ template // Adjust positions const int block_idx = tid.z; - const int head_idx = tid.x; + const int q_batch_head_idx = tid.x; const int q_seq_idx = tid.y; - const int o_offset = head_idx * tpg.y + q_seq_idx; + const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx; const int q_offset = - query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset; - const int kv_head_idx = head_idx / gqa_factor; + query_transposed ? 
tpg.x * q_seq_idx + q_batch_head_idx : o_offset; + const int kv_head_idx = q_batch_head_idx / gqa_factor; queries += q_offset * D + simd_lid * qk_per_thread; keys += kv_head_idx * k_head_stride + @@ -225,12 +239,12 @@ template (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread; out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread; if (bool_mask) { - bmask += head_idx * mask_head_stride + + bmask += q_batch_head_idx * mask_head_stride + (block_idx * BN + simd_gid) * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; } if (float_mask) { - fmask += head_idx * mask_head_stride + + fmask += q_batch_head_idx * mask_head_stride + (block_idx * BN + simd_gid) * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; } @@ -245,8 +259,13 @@ template o[i] = 0; } - U max_score = -1e9; + U max_score = -INFINITY; U sum_exp_score = 0; + if (has_sinks && block_idx == 0 && simd_gid == 0) { + int q_head_idx = q_batch_head_idx % num_q_heads; + max_score = static_cast(sinks[q_head_idx]); + sum_exp_score = 1; + } // For each key for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { @@ -255,6 +274,8 @@ template use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); } else if (bool_mask) { use_key = bmask[0]; + } else if (float_mask) { + use_key = (fmask[0] >= Limits::finite_min); } if (use_key) { // Read the key @@ -268,6 +289,10 @@ template score += q[i] * k[i]; } score = simd_sum(score); + if (score < Limits::finite_min) { + continue; + } + if (float_mask) { score += fmask[0]; } diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index 34d5bf58a..7397039b5 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -11,6 +11,7 @@ constant bool align_K [[function_constant(201)]]; constant bool has_mask [[function_constant(300)]]; constant bool do_causal [[function_constant(301)]]; +constant bool has_sinks [[function_constant(302)]]; template struct TransformScale { @@ -82,6 +83,7 @@ template < const constant AttnParams* params [[buffer(4)]], const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], const device MaskType* mask [[buffer(6), function_constant(has_mask)]], + const device T* sinks [[buffer(7), function_constant(has_sinks)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], @@ -169,7 +171,7 @@ template < VBlockLoader loader_v( V, params->V_strides[2], Vs, simd_group_id, simd_lane_id); - TransformScale ts(static_cast(params->scale * 1.44269504089)); + TransformScale ts(static_cast(params->scale * M_LOG2E_F)); // Prepare MMA tiles constexpr short kFragSize = 8; // MMAFrag size @@ -232,6 +234,14 @@ template < max_score[i] = Limits::finite_min; } + if (has_sinks) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = M_LOG2E_F * static_cast(sinks[tidl.y]); + sum_score[i] = 1; + } + } + int kb_lim = params->NK; if (do_causal) { @@ -350,7 +360,7 @@ template < Stile.frag_at(i, j)[jj] = mfrag[jj] ? 
Stile.frag_at(i, j)[jj] : neg_inf; } else { - Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]); + Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]); } } } diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index cc27bff2d..0aca3170e 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -21,8 +21,9 @@ void sdpa_full_self_attention_metal( const array& v, const float scale, array& o, - bool do_causal_ = false, - const std::optional& mask = std::nullopt) { + bool do_causal_, + const std::optional& mask, + const std::optional& sinks) { using namespace mlx::steel; int wm = 4; @@ -42,35 +43,49 @@ void sdpa_full_self_attention_metal( const bool align_Q = (qL % bq) == 0; const bool align_K = (kL % bk) == 0; - const bool has_mask = !!mask; + const bool has_mask = mask.has_value(); const bool do_causal = do_causal_; + const bool has_sinks = sinks.has_value(); metal::MTLFCList func_consts = { {&align_Q, MTL::DataType::DataTypeBool, 200}, {&align_K, MTL::DataType::DataTypeBool, 201}, {&has_mask, MTL::DataType::DataTypeBool, 300}, - {&do_causal, MTL::DataType::DataTypeBool, 301}}; + {&do_causal, MTL::DataType::DataTypeBool, 301}, + {&has_sinks, MTL::DataType::DataTypeBool, 302}}; - std::ostringstream kname; - // clang-format off - kname << "steel_attention_" - << type_to_name(q) - << "_bq" << bq - << "_bk" << bk - << "_bd" << bd - << "_wm" << wm - << "_wn" << wn - << "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on + std::string base_name; + concatenate( + base_name, + "steel_attention_", + type_to_name(q), + "_bq", + bq, + "_bk", + bk, + "_bd", + bd, + "_wm", + wm, + "_wn", + wn, + "_mask", + type_to_name(has_mask ? *mask : q)); - std::string base_name = kname.str(); - - // clang-format off - kname << "_align_Q_" << (align_Q ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n') - << "_has_mask_" << (has_mask ? 't' : 'n') - << "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on - - std::string hash_name = kname.str(); + std::string hash_name; + concatenate( + hash_name, + base_name, + "_align_Q_", + (align_Q ? 't' : 'n'), + "_align_K_", + (align_K ? 't' : 'n'), + "_has_mask_", + (has_mask ? 't' : 'n'), + "_do_causal_", + (do_causal ? 't' : 'n'), + "_has_sinks_", + (has_sinks ? 
't' : 'n')); auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(base_name, hash_name, func_consts); @@ -114,8 +129,8 @@ void sdpa_full_self_attention_metal( compute_encoder.set_output_array(o, 3); compute_encoder.set_bytes(params, 4); - if (mask) { - auto m = *mask; + if (has_mask) { + auto& m = *mask; AttnMaskParams mask_params{/* int64_t M_strides[3] = */ { m.strides(0), m.strides(1), m.strides(2)}}; @@ -123,6 +138,9 @@ void sdpa_full_self_attention_metal( compute_encoder.set_bytes(mask_params, 5); compute_encoder.set_input_array(m, 6); } + if (has_sinks) { + compute_encoder.set_input_array(*sinks, 7); + } MTL::Size grid_dims = MTL::Size(NQ, H, B); MTL::Size group_dims = MTL::Size(32, wm, wn); @@ -139,7 +157,8 @@ void sdpa_vector( array& out, float scale, bool do_causal, - const std::optional& mask) { + const std::optional& mask, + const std::optional& sinks) { // Set the kernel name std::string kname; kname.reserve(64); @@ -153,30 +172,32 @@ void sdpa_vector( // Compute the necessary sizes int gqa_factor = q.shape(1) / k.shape(1); int N = k.shape(2); - int B = q.shape(0) * q.shape(1); size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); size_t v_seq_stride = v.strides()[2]; MTL::Size group_dims(1024, 1, 1); - MTL::Size grid_dims(B, q.shape(2), 1); + MTL::Size grid_dims(q.shape(0) * q.shape(1), q.shape(2), 1); bool has_mask = mask.has_value(); bool bool_mask = has_mask && (*mask).dtype() == bool_; bool float_mask = has_mask && !bool_mask; bool query_transposed = !q.flags().row_contiguous; + bool has_sinks = sinks.has_value(); metal::MTLFCList func_consts = { {&has_mask, MTL::DataType::DataTypeBool, 20}, {&query_transposed, MTL::DataType::DataTypeBool, 21}, {&do_causal, MTL::DataType::DataTypeBool, 22}, {&bool_mask, MTL::DataType::DataTypeBool, 23}, {&float_mask, MTL::DataType::DataTypeBool, 24}, + {&has_sinks, MTL::DataType::DataTypeBool, 25}, }; std::string hash_name = kname; hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask"; hash_name += query_transposed ? "_qt" : "_qnt"; hash_name += do_causal ? "_c" : "_nc"; + hash_name += has_sinks ? "_sinks" : "_nosinks"; // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); @@ -207,6 +228,10 @@ void sdpa_vector( compute_encoder.set_bytes(q_seq_stride, 14); compute_encoder.set_bytes(head_stride, 15); } + if (has_sinks) { + compute_encoder.set_input_array(*sinks, 16); + compute_encoder.set_bytes(q.shape(1), 17); + } // Launch compute_encoder.dispatch_threadgroups(grid_dims, group_dims); @@ -221,7 +246,8 @@ void sdpa_vector_2pass( array& out, float scale, bool do_causal, - const std::optional& mask) { + const std::optional& mask, + const std::optional& sinks) { // Set the kernel name std::string kname; kname.reserve(64); @@ -267,17 +293,20 @@ void sdpa_vector_2pass( bool bool_mask = has_mask && (*mask).dtype() == bool_; bool float_mask = has_mask && !bool_mask; bool query_transposed = !q.flags().row_contiguous; + bool has_sinks = sinks.has_value(); metal::MTLFCList func_consts = { {&has_mask, MTL::DataType::DataTypeBool, 20}, {&query_transposed, MTL::DataType::DataTypeBool, 21}, {&do_causal, MTL::DataType::DataTypeBool, 22}, {&bool_mask, MTL::DataType::DataTypeBool, 23}, {&float_mask, MTL::DataType::DataTypeBool, 24}, + {&has_sinks, MTL::DataType::DataTypeBool, 25}, }; std::string hash_name = kname; hash_name += has_mask ? (bool_mask ? 
"_boolmask" : "_floatmask") : "_nomask"; hash_name += query_transposed ? "_qt" : "_qnt"; hash_name += do_causal ? "_c" : "_nc"; + hash_name += has_sinks ? "_sinks" : "_nosinks"; // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); @@ -310,6 +339,10 @@ void sdpa_vector_2pass( compute_encoder.set_bytes(q_seq_stride, 16); compute_encoder.set_bytes(head_stride, 17); } + if (has_sinks) { + compute_encoder.set_input_array(*sinks, 18); + compute_encoder.set_bytes(q.shape(1), 19); + } // Launch compute_encoder.dispatch_threadgroups(grid_dims, group_dims); @@ -411,6 +444,12 @@ void ScaledDotProductAttention::eval_gpu( return arr.strides(-1) == 1; }; + std::optional sinks = std::nullopt; + if (has_sinks_) { + sinks = copy_unless(is_matrix_contiguous, inputs.back()); + } + bool has_arr_mask = inputs.size() > (3 + has_sinks_); + // We are in vector mode ie single query if (q_pre.shape(2) <= 8) { auto q_copy_unless = [](const array& arr) { @@ -462,7 +501,7 @@ void ScaledDotProductAttention::eval_gpu( (strides[0] == strides[1] * shape[1]); }; - auto mask = inputs.size() > 3 + auto mask = has_arr_mask ? std::optional{copy_unless(mask_copy_unless, inputs[3])} : std::nullopt; @@ -473,9 +512,9 @@ void ScaledDotProductAttention::eval_gpu( char devc = d.get_architecture().back(); if ((devc == 'd' && k.shape(2) >= 1024) || (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) { - sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask); + sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask, sinks); } else { - sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask); + sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask, sinks); } } @@ -503,11 +542,12 @@ void ScaledDotProductAttention::eval_gpu( {str_oB, str_oH, str_oL, str_oD}, flags); - auto mask = inputs.size() > 3 + auto mask = has_arr_mask ? 
std::optional{copy_unless(is_matrix_contiguous, inputs[3])} : std::nullopt; - sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask); + sdpa_full_self_attention_metal( + s, d, q, k, v, scale_, o, do_causal_, mask, sinks); } d.add_temporaries(std::move(copies), s.index); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 254bbde77..324316507 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -579,6 +579,7 @@ array scaled_dot_product_attention( const float scale, const std::string& mask_mode /* = "" */, const std::vector& mask_arrs /* = {} */, + const std::optional& sinks /* = {} */, StreamOrDevice s /* = {}*/) { for (const auto& tensor : {queries, keys, values}) { if (tensor.ndim() != 4) { @@ -679,13 +680,20 @@ array scaled_dot_product_attention( << final_type << "."; throw std::invalid_argument(msg.str()); } + bool has_sinks = sinks.has_value(); auto q = astype(queries, final_type, s); auto k = astype(keys, final_type, s); auto v = astype(values, final_type, s); - auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s]( - const std::vector& inputs) { + auto fallback = [scale, + final_type, + n_q_heads, + n_kv_heads, + do_causal, + has_sinks, + has_arr_mask, + s](const std::vector& inputs) { auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s); int n_repeats = n_q_heads / n_kv_heads; int B = q.shape(0); @@ -698,20 +706,22 @@ array scaled_dot_product_attention( v = expand_dims(v, 2, s); } auto scores = matmul(q, swapaxes(k, -1, -2, s), s); - if (inputs.size() > 3 || do_causal) { + if (has_arr_mask || do_causal) { // Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv] - auto mask = inputs.back(); - - if (do_causal) { - int kL = k.shape(-2); - int qL = q.shape(-2); - int q_off = (kL - qL) < 0 ? 0 : (kL - qL); - auto q_idx = arange(q_off, q_off + qL, s); - auto k_idx = arange(0, kL, s); - q_idx = expand_dims(q_idx, 1, s); - k_idx = expand_dims(k_idx, 0, s); - mask = greater_equal(q_idx, k_idx, s); - } + auto make_or_fetch_mask = [&]() { + if (do_causal) { + int kL = k.shape(-2); + int qL = q.shape(-2); + int q_off = (kL - qL) < 0 ? 0 : (kL - qL); + auto q_idx = arange(q_off, q_off + qL, s); + auto k_idx = arange(0, kL, s); + q_idx = expand_dims(q_idx, 1, s); + k_idx = expand_dims(k_idx, 0, s); + return greater_equal(q_idx, k_idx, s); + } + return inputs[3]; + }; + auto mask = make_or_fetch_mask(); if (n_repeats > 1 && mask.ndim() >= 3) { if (mask.shape(-3) == 1) { @@ -730,7 +740,25 @@ array scaled_dot_product_attention( scores = add(scores, mask, s); } } + if (has_sinks) { + auto sinks = inputs.back(); + // scores has shape B N_q N_k L_q L_k + sinks = expand_dims(sinks, {0, 2, 3}, s); + if (scores.ndim() == 5) { + sinks = unflatten(sinks, 1, {n_kv_heads, n_repeats}, s); + } + auto bsx_shape = scores.shape(); + bsx_shape.back() = 1; + scores = concatenate({broadcast_to(sinks, bsx_shape, s), scores}, -1, s); + } scores = softmax(scores, std::vector{-1}, true, s); + if (has_sinks) { + // Slice off scores + auto start = Shape(scores.ndim(), 0); + start.back() = 1; + auto stop = scores.shape(); + scores = slice(scores, std::move(start), std::move(stop), s); + } auto out = matmul(scores, v, s); if (n_repeats > 1) { out = flatten(out, 1, 2, s); @@ -746,7 +774,7 @@ array scaled_dot_product_attention( has_bool_mask = mask_arr.dtype() == bool_; if (promote_types(mask_arr.dtype(), final_type) != final_type) { std::ostringstream msg; - msg << "[scaled_dot_product_attention] Mask type must promote to output type. 
" + msg << "[scaled_dot_product_attention] Mask type must promote to output type " << final_type << "."; throw std::invalid_argument(msg.str()); } else if (!has_bool_mask) { @@ -757,6 +785,22 @@ array scaled_dot_product_attention( mask_shape.back() = keys.shape(-2); inputs.push_back(broadcast_to(mask_arr, mask_shape, stream)); } + if (has_sinks) { + if (promote_types(sinks->dtype(), final_type) != final_type) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] Type of sinks must promote to output type " + << final_type << "."; + throw std::invalid_argument(msg.str()); + } + if (sinks->ndim() != 1 || sinks->shape(0) != n_q_heads) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] Received invalid shape for sinks " + << sinks->shape() << "."; + throw std::invalid_argument(msg.str()); + } + inputs.push_back(astype(*sinks, final_type, stream)); + } + if (!ScaledDotProductAttention::use_fallback( q, k, v, has_mask, has_arr_mask, do_causal, stream)) { auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}; @@ -764,7 +808,7 @@ array scaled_dot_product_attention( std::move(out_shape), final_type, std::make_shared( - stream, fallback, scale, do_causal), + stream, fallback, scale, do_causal, has_sinks), std::move(inputs)); } return fallback(std::move(inputs))[0]; @@ -773,7 +817,8 @@ array scaled_dot_product_attention( bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { const ScaledDotProductAttention& a_other = static_cast(other); - return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_; + return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ && + has_sinks_ == a_other.has_sinks_; } bool Quantize::is_equivalent(const Primitive& other) const { diff --git a/mlx/fast.h b/mlx/fast.h index 10f9ced96..3cbff60e4 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -50,6 +50,7 @@ array scaled_dot_product_attention( const float scale, const std::string& mask_mode = "", const std::vector& mask_arrs = {}, + const std::optional& sinks = {}, StreamOrDevice s = {}); using TemplateArg = std::variant; diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index fd6ba8fed..a8000485a 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -208,9 +208,13 @@ class ScaledDotProductAttention : public Custom { explicit ScaledDotProductAttention( Stream stream, std::function(std::vector)> fallback, - const float scale, - const bool do_causal) - : Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {} + float scale, + bool do_causal, + bool has_sinks) + : Custom(stream, fallback), + scale_(scale), + do_causal_(do_causal), + has_sinks_(has_sinks) {} static bool use_fallback( const array& q, @@ -237,12 +241,13 @@ class ScaledDotProductAttention : public Custom { DEFINE_NAME(ScaledDotProductAttention); DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { - return std::make_tuple(nullptr, scale_, do_causal_); + return std::make_tuple(nullptr, scale_, do_causal_, has_sinks_); } private: float scale_; bool do_causal_; + bool has_sinks_; }; class Quantize : public Custom { diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 7b484559f..0ed1aa698 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -196,6 +196,7 @@ void init_fast(nb::module_& parent_module) { const mx::array& values, const float scale, const std::variant& mask, + const std::optional& sinks, mx::StreamOrDevice s) { bool has_mask = !std::holds_alternative(mask); bool has_str_mask = @@ -212,16 +213,16 @@ void init_fast(nb::module_& 
parent_module) { throw std::invalid_argument(msg.str()); } return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, mask_str, {}, s); + queries, keys, values, scale, mask_str, {}, sinks, s); } else { auto mask_arr = std::get(mask); return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, "", {mask_arr}, s); + queries, keys, values, scale, "", {mask_arr}, sinks, s); } } else { return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, "", {}, s); + queries, keys, values, scale, "", {}, sinks, s); } }, "q"_a, @@ -230,9 +231,10 @@ void init_fast(nb::module_& parent_module) { nb::kw_only(), "scale"_a, "mask"_a = nb::none(), + "sinks"_a = nb::none(), "stream"_a = nb::none(), nb::sig( - "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, stream: Union[None, Stream, Device] = None) -> array"), + "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, sinks: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``. @@ -262,14 +264,17 @@ void init_fast(nb::module_& parent_module) { q (array): Queries with shape ``[B, N_q, T_q, D]``. k (array): Keys with shape ``[B, N_kv, T_kv, D]``. v (array): Values with shape ``[B, N_kv, T_kv, D]``. - scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) - mask (Union[None, str, array], optional): The mask to apply to the + scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``). + mask (str or array, optional): The mask to apply to the query-key scores. The mask can be an array or a string indicating the mask type. The only supported string type is ``"causal"``. If the mask is an array it can be a boolean or additive mask. The mask can have at most 4 dimensions and must be broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. If an additive mask is given its type must promote to the promoted type of ``q``, ``k``, and ``v``. + sinks (array, optional): An optional array of attention sinks. + Default: ``None``. + Returns: array: The output array. 
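
Note on the binding above: `sinks` is exposed as a keyword-only argument with one entry per query head. A minimal usage sketch (not from the patch itself; it assumes a build of `mlx` that includes this change, and the shapes follow the docstring: `sinks` has shape `(N_q,)`):

import mlx.core as mx

B, N_q, N_kv, T, D = 1, 8, 2, 128, 64
q = mx.random.normal(shape=(B, N_q, T, D))
k = mx.random.normal(shape=(B, N_kv, T, D))
v = mx.random.normal(shape=(B, N_kv, T, D))
sinks = mx.random.normal(shape=(N_q,))  # one sink logit per query head

# The sink logits join the softmax normalization but contribute no value vector.
out = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=D**-0.5, mask="causal", sinks=sinks
)
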
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index abc9ada9d..52ecc9be0 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -6,7 +6,7 @@ import mlx_tests import numpy as np -def mlx_ref_attn(q, k, v, scale=1.0, mask=None): +def mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None): q_dtype = q.dtype q = q * mx.array(scale, q_dtype) n_q_heads = q.shape[-3] @@ -23,7 +23,6 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None): v = mx.expand_dims(v, 2) scores = q @ mx.swapaxes(k, -1, -2) - if mask is not None: if mask == "causal": @@ -43,7 +42,18 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None): else: scores += mask + if sinks is not None: + sinks = mx.expand_dims(sinks, (0, 2, 3)) + if n_repeats > 1: + sinks = mx.unflatten(sinks, 1, (n_kv_heads, n_repeats)) + score_shape = list(scores.shape) + score_shape[-1] = 1 + sinks = mx.broadcast_to(sinks, score_shape) + scores = mx.concatenate([sinks, scores], axis=-1) + scores = mx.softmax(scores, axis=-1, precise=True) + if sinks is not None: + scores = scores[..., 1:] out = scores @ v if n_repeats > 1: @@ -158,7 +168,7 @@ class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase): Dk = 64 - if self.is_apple_silicon: + if self.is_apple_silicon or mx.cuda.is_available(): dtypes.append(np.half) for SEQUENCE_LENGTH in [63, 129, 400]: @@ -230,7 +240,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase): B = 1 H = 32 dtypes = [np.float32] - if self.is_apple_silicon: + if self.is_apple_silicon or mx.cuda.is_available(): dtypes.append(np.half) for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]: @@ -400,15 +410,30 @@ class TestFastSDPA(mlx_tests.MLXTestCase): def test_fully_masked(self): Lkv = 8 - mask = mx.array(False) + masks = [mx.array(False), mx.array(-float("inf"))] + for mask in masks: + for D in [4, 128]: + for Lq in [1, 8]: + q = mx.random.normal(shape=(1, 4, Lq, D)) + k = mx.random.normal(shape=(1, 4, Lkv, D)) + v = mx.random.normal(shape=(1, 4, Lkv, D)) + + out = mx.fast.scaled_dot_product_attention( + q, k, v, mask=mask, scale=1 + ) + self.assertTrue(mx.all(mx.isnan(out))) + + def test_inf_score(self): + Lkv = 8 for D in [4, 128]: for Lq in [1, 8]: - q = mx.random.normal(shape=(1, 4, Lq, D)) - k = mx.random.normal(shape=(1, 4, Lkv, D)) + q = mx.ones(shape=(1, 4, Lq, D)) + k = mx.ones(shape=(1, 4, Lkv, D)) v = mx.random.normal(shape=(1, 4, Lkv, D)) - - out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1) - self.assertTrue(mx.all(mx.isnan(out))) + k[..., 0, :] = -float("inf") + ref = mlx_primitives_sdpa(q, k, v, scale=1, mask=None) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) def test_fast_sdpa_few_query(self): D = 64 @@ -674,6 +699,51 @@ class TestSDPA(mlx_tests.MLXTestCase): self.assertFalse(mx.isnan(out).any().item()) self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4) + def test_sdpa_attention_sinks(self): + B = 2 + N_q = N_kv = 8 + T_q = T_kv = 128 + D = 64 + + q = mx.random.normal(shape=(B, N_q, T_q, D)) + k = mx.random.normal(shape=(B, N_kv, T_kv, D)) + v = mx.random.normal(shape=(B, N_kv, T_kv, D)) + scale = D**-0.5 + + # sinks should promote to correct type + sinks = mx.random.normal(shape=(N_q,)) + with self.assertRaises(ValueError): + mx.fast.scaled_dot_product_attention( + q.astype(mx.float16), + k.astype(mx.float16), + v.astype(mx.float16), + scale=scale, + sinks=sinks, + ) + + # Wrong shapes + sinks = mx.random.normal(shape=(N_q + 1,)) + with 
self.assertRaises(ValueError): + mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks) + + sinks = mx.random.normal(shape=()) + with self.assertRaises(ValueError): + mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks) + + for T_kv in [128, 4096]: + for T_q in [1, 128]: + for N_kv in [2, 8]: + q = mx.random.normal(shape=(B, N_q, T_q, D)) + k = mx.random.normal(shape=(B, N_kv, T_kv, D)) + v = mx.random.normal(shape=(B, N_kv, T_kv, D)) + sinks = 10 * mx.random.normal(shape=(N_q,)) + + expected = mlx_ref_attn(q, k, v, scale, sinks=sinks) + out = mx.fast.scaled_dot_product_attention( + q, k, v, scale=scale, sinks=sinks + ) + self.assertTrue(mx.allclose(out, expected, atol=1e-5)) + if __name__ == "__main__": mlx_tests.MLXTestRunner(failfast=True)
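
For reference, the vector kernels in this patch fold the sink into the running (online) softmax instead of materializing an extra score column: they initialize the running maximum to the sink logit and the running sum to 1 (the CUDA path does the same in log2 space, hence the `M_LOG2E` factor), and they guard the rescaling factors when the running maximum is still `-INFINITY`. A minimal pure-Python sketch of why this matches the reference path that concatenates a sink column before the softmax; the function names here are illustrative, not from the repository:

import math

def online_softmax_weighted_sum(scores, values, sink):
    # Running state starts as if a "sink" key with logit `sink` and a zero
    # value vector had already been processed (max = sink, sum_exp = 1).
    max_score, sum_exp, acc = sink, 1.0, 0.0
    for s, v in zip(scores, values):
        new_max = max(max_score, s)
        if new_max == -math.inf:
            # Mirrors the -INFINITY guard added in the kernels: a fully
            # -inf prefix contributes nothing and must not produce NaNs.
            factor, exp_s = 1.0, 0.0
        else:
            factor = math.exp(max_score - new_max)
            exp_s = math.exp(s - new_max)
        acc = acc * factor + exp_s * v
        sum_exp = sum_exp * factor + exp_s
        max_score = new_max
    return acc / sum_exp

def reference(scores, values, sink):
    # Reference path: prepend the sink logit as an extra column, softmax,
    # then drop its weight (its value contribution is zero).
    logits = [sink] + list(scores)
    m = max(logits)
    w = [math.exp(x - m) for x in logits]
    z = sum(w)
    return sum(wi * vi for wi, vi in zip(w[1:], values)) / z

scores, values = [0.3, -1.2, 2.0], [1.0, 2.0, 3.0]
assert abs(online_softmax_weighted_sum(scores, values, 0.5)
           - reference(scores, values, 0.5)) < 1e-9

The same equivalence is what `test_sdpa_attention_sinks` checks end to end against `mlx_ref_attn`, which implements the concatenate-then-slice variant.
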