Add sdpa with sinks (#2558)

* add sdpa with sinks

* fix 2 pass

* fix matrix sdpa

* fix perf regression

* add to cuda (#2580)
Author:     Awni Hannun
Date:       2025-09-10 14:53:00 -07:00
Committer:  GitHub
Parent:     db5443e831
Commit:     d6977f2a57

9 changed files with 349 additions and 114 deletions
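
Note: for orientation, the semantics all the kernels below implement, written as a minimal scalar sketch (softmax_with_sink is a hypothetical helper, not part of this change). Each query head gets one extra per-head logit, the sink, which joins the softmax normalization but contributes no value vector, so it only absorbs probability mass.

#include <algorithm>
#include <cmath>
#include <vector>

// One query row: `scores` holds q . k_i * scale, plus a per-head sink logit.
std::vector<float> softmax_with_sink(const std::vector<float>& scores, float sink) {
  float m = sink;
  for (float s : scores)
    m = std::max(m, s);
  float denom = std::exp(sink - m); // the sink's share of the normalizer
  for (float s : scores)
    denom += std::exp(s - m);
  std::vector<float> p(scores.size());
  for (size_t i = 0; i < scores.size(); ++i)
    p[i] = std::exp(scores[i] - m) / denom;
  return p; // sums to < 1; the remainder went to the sink
}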

View File

@@ -46,6 +46,7 @@ __global__ void kernel_sdpav_1pass(
const T* K,
const T* V,
T* O,
const T* sinks,
__grid_constant__ const AttnParams params) {
constexpr int BN = 32;
constexpr int BD = 32;
@@ -65,7 +66,7 @@ __global__ void kernel_sdpav_1pass(
__shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN];
const U scale_log2 = params.scale * 1.44269504089f;
const U scale_log2 = params.scale * M_LOG2E;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
@@ -110,6 +111,10 @@ __global__ void kernel_sdpav_1pass(
U max_score = -INFINITY;
U sum_exp_score = 0.f;
if (sinks && warp_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
sum_exp_score = 1.f;
}
// For each key
for (int i = kv_seq_idx; i < params.kL; i += BN) {
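
Note on the seeding above (a sketch of the identity, not kernel code): this CUDA kernel accumulates in base 2 via exp2f, so the sink logit is pre-scaled by M_LOG2E. With the running max initialized to that value, the sink's own term is exp2 of zero, which is why sum_exp_score starts at 1. Here sink_logit stands in for sinks[head_idx]:

// Seed the online softmax with the sink as its first "score".
float s0 = 1.44269504089f * sink_logit; // sink in the log2 domain
float max_score = s0;
float sum_exp_score = exp2f(s0 - max_score); // exp2f(0.f) == 1.f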
@@ -137,8 +142,9 @@ __global__ void kernel_sdpav_1pass(
// Update the accumulators
U new_max = max(max_score, score);
U factor = exp2f(max_score - new_max);
U exp_score = exp2f(score - new_max);
bool is_neg_inf = new_max == -INFINITY;
U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
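
Why the is_neg_inf guard: before any unmasked key has contributed, both max_score and a fully masked score are -INFINITY, and exp2f(-INFINITY - -INFINITY) evaluates exp2f(NaN). Forcing factor = 1 and exp_score = 0 keeps the accumulators well defined. A self-contained sketch mirroring the change:

#include <cmath>

// Guarded online-softmax update in base 2, as in the kernels above.
inline void online_update(float score, float& max_score, float& sum_exp_score) {
  float new_max = fmaxf(max_score, score);
  bool is_neg_inf = new_max == -INFINITY; // nothing has contributed yet
  float factor = is_neg_inf ? 1.f : exp2f(max_score - new_max);
  float exp_score = is_neg_inf ? 0.f : exp2f(score - new_max);
  max_score = new_max;
  sum_exp_score = sum_exp_score * factor + exp_score;
}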
@@ -193,6 +199,7 @@ __global__ void kernel_sdpav_2pass_1(
const T* Q,
const T* K,
const T* V,
const T* sinks,
float* partials,
float* sums,
float* maxs,
@@ -268,8 +275,12 @@ __global__ void kernel_sdpav_2pass_1(
o[i] = 0.f;
}
U max_score = -1e9;
U max_score = -INFINITY;
U sum_exp_score = 0.f;
if (sinks && warp_idx == 0 && block_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
sum_exp_score = 1.f;
}
// For each key
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
@@ -297,8 +308,9 @@ __global__ void kernel_sdpav_2pass_1(
// Update the accumulators
U new_max = max(max_score, score);
U factor = exp2f(max_score - new_max);
U exp_score = exp2f(score - new_max);
bool is_neg_inf = new_max == -INFINITY;
U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
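
For context on the 2-pass path (the merge kernel itself is unchanged and not shown in this diff): pass 1 writes per-block partial outputs plus their (max, sum) statistics into partials/sums/maxs, and pass 2 combines them. Seeding the sink only when block_idx == 0 and warp_idx == 0 ensures it is counted exactly once after the merge. A host-side sketch of that merge, assuming it follows the standard log-sum-exp form:

// Merge `blocks` partial softmax states for one output row (sketch).
float M = -INFINITY;
for (int b = 0; b < blocks; ++b)
  M = fmaxf(M, maxs[b]);
float S = 0.f;
for (int b = 0; b < blocks; ++b)
  S += sums[b] * exp2f(maxs[b] - M); // rescale each block's sum to max M
// out[d] = (1 / S) * sum over b of exp2f(maxs[b] - M) * partials[b][d]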
@@ -463,10 +475,14 @@ void sdpa_vector_1pass_fallback(
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
bool do_causal,
const std::optional<array>& sinks) {
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
if (sinks) {
encoder.set_input_array(*sinks);
}
encoder.set_output_array(o);
cu::AttnParams params{
@@ -489,7 +505,7 @@ void sdpa_vector_1pass_fallback(
dim3 block_dim(1024, 1, 1);
dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_bool(do_causal, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -504,6 +520,7 @@ void sdpa_vector_1pass_fallback(
k.data<DataType>(),
v.data<DataType>(),
o.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
params);
});
});
@@ -518,7 +535,8 @@ void sdpa_vector_2pass_fallback(
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
bool do_causal,
const std::optional<array>& sinks) {
cu::AttnParams params{
/* int B = */ q.shape(0),
/* int H = */ q.shape(1),
@@ -559,7 +577,7 @@ void sdpa_vector_2pass_fallback(
encoder.add_temporary(maxs);
dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_bool(do_causal, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -570,6 +588,10 @@ void sdpa_vector_2pass_fallback(
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
if (sinks) {
encoder.set_input_array(*sinks);
}
encoder.set_output_array(intermediate);
encoder.set_output_array(sums);
encoder.set_output_array(maxs);
@@ -585,6 +607,7 @@ void sdpa_vector_2pass_fallback(
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
@@ -627,15 +650,16 @@ void sdpa_vector_fallback(
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
bool do_causal,
const std::optional<array>& sinks) {
int kL = k.shape(2);
if (kL > 1024) {
return sdpa_vector_2pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_);
s, encoder, q, k, v, scale, o, do_causal, sinks);
} else {
return sdpa_vector_1pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_);
s, encoder, q, k, v, scale, o, do_causal, sinks);
}
}
@@ -691,7 +715,7 @@ void ScaledDotProductAttention::eval_gpu(
// Define some copy functions to ensure the layout of the inputs is as
// expected.
copies.reserve(3);
copies.reserve(inputs.size());
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
@@ -703,6 +727,16 @@ void ScaledDotProductAttention::eval_gpu(
}
};
// Checks that the headdim dimension has stride 1.
auto is_matrix_contiguous = [](const array& arr) {
return arr.strides(-1) == 1;
};
std::optional<array> sinks = std::nullopt;
if (has_sinks_) {
sinks = copy_unless(is_matrix_contiguous, inputs.back());
}
// We are in vector mode ie single query
if (q_pre.shape(2) < 4) {
auto q_copy_unless = [](const array& arr) {
@@ -740,10 +774,6 @@ void ScaledDotProductAttention::eval_gpu(
const auto& k = copy_unless(kv_copy_unless, k_pre);
const auto& v = copy_unless(kv_copy_unless, v_pre);
for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
// Donate the query if possible
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
o.copy_shared_buffer(q);
@@ -752,22 +782,26 @@ void ScaledDotProductAttention::eval_gpu(
int64_t str_oH = o.shape(3);
int64_t str_oL = o.shape(1) * str_oH;
int64_t str_oB = o.shape(2) * str_oL;
size_t data_size = o.shape(0) * str_oB;
array::Flags flags{
/* bool contiguous = */ 1,
/* bool row_contiguous = */ o.shape(2) == 1,
/* bool col_contiguous = */ 0,
/* bool col_contiguous = */ o.size() == o.shape(3),
};
o.set_data(
allocator::malloc(o.nbytes()),
data_size,
o.size(),
{str_oB, str_oH, str_oL, str_oD},
flags);
}
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
return sdpa_vector_fallback(
s, encoder, q, k, v, scale_, o, do_causal_, sinks);
}
// Full attention mode should never reach here

View File

@@ -9,6 +9,7 @@ constant bool query_transposed [[function_constant(21)]];
constant bool do_causal [[function_constant(22)]];
constant bool bool_mask [[function_constant(23)]];
constant bool float_mask [[function_constant(24)]];
constant bool has_sinks [[function_constant(25)]];
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
@@ -31,6 +32,9 @@ template <typename T, int D, int V = D>
[[buffer(14), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(15), function_constant(has_mask)]],
const device T* sinks [[buffer(16), function_constant(has_sinks)]],
const constant int& num_q_heads
[[buffer(17), function_constant(has_sinks)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -53,24 +57,24 @@ template <typename T, int D, int V = D>
threadgroup U sum_exp_scores[BN];
// Adjust positions
const int head_idx = tid.x;
const int q_batch_head_idx = tid.x;
const int q_seq_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
const int o_offset = head_idx * tpg.y + q_seq_idx;
const int kv_head_idx = q_batch_head_idx / gqa_factor;
const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
const int q_offset =
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
simd_lid * qk_per_thread;
values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
simd_lid * v_per_thread;
if (bool_mask) {
bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
bmask += q_batch_head_idx * mask_head_stride +
simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
fmask += q_batch_head_idx * mask_head_stride +
simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
}
out += o_offset * V + simd_gid * v_per_thread;
@@ -85,6 +89,10 @@ template <typename T, int D, int V = D>
U max_score = -INFINITY;
U sum_exp_score = 0;
if (has_sinks && simd_gid == 0) {
max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
sum_exp_score = 1;
}
// For each key
for (int i = simd_gid; i < N; i += BN) {
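
Note the missing M_LOG2E here relative to the CUDA kernels earlier in this commit: these Metal vector kernels accumulate with fast::exp in base e, so the sink logit is read unscaled. Both conventions yield the same weights, per this identity (sketch):

// exp(x) == exp2(x * log2(e)), so base-e and base-2 accumulation agree.
inline float weight_base_e(float s, float m) { return expf(s - m); }
inline float weight_base_2(float s, float m) {
  return exp2f((s - m) * 1.44269504089f); // log2(e); equals weight_base_e
}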
@@ -93,6 +101,8 @@ template <typename T, int D, int V = D>
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
} else if (float_mask) {
use_key = (fmask[0] >= Limits<T>::finite_min);
}
if (use_key) {
// Read the key
@@ -107,13 +117,14 @@ template <typename T, int D, int V = D>
}
score = simd_sum(score);
if (float_mask) {
score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
score += static_cast<U>(fmask[0]);
}
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
bool is_neg_inf = new_max == -INFINITY;
U factor = is_neg_inf ? 1.0 : fast::exp(max_score - new_max);
U exp_score = is_neg_inf ? 0.0 : fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
@@ -187,6 +198,9 @@ template <typename T, int D, int V = D>
[[buffer(16), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(17), function_constant(has_mask)]],
const device T* sinks [[buffer(18), function_constant(has_sinks)]],
const constant int& num_q_heads
[[buffer(19), function_constant(has_sinks)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -211,12 +225,12 @@ template <typename T, int D, int V = D>
// Adjust positions
const int block_idx = tid.z;
const int head_idx = tid.x;
const int q_batch_head_idx = tid.x;
const int q_seq_idx = tid.y;
const int o_offset = head_idx * tpg.y + q_seq_idx;
const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
const int q_offset =
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
const int kv_head_idx = head_idx / gqa_factor;
query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
const int kv_head_idx = q_batch_head_idx / gqa_factor;
queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_head_stride +
@@ -225,12 +239,12 @@ template <typename T, int D, int V = D>
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
if (bool_mask) {
bmask += head_idx * mask_head_stride +
bmask += q_batch_head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride +
fmask += q_batch_head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
@@ -245,8 +259,13 @@ template <typename T, int D, int V = D>
o[i] = 0;
}
U max_score = -1e9;
U max_score = -INFINITY;
U sum_exp_score = 0;
if (has_sinks && block_idx == 0 && simd_gid == 0) {
int q_head_idx = q_batch_head_idx % num_q_heads;
max_score = static_cast<U>(sinks[q_head_idx]);
sum_exp_score = 1;
}
// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
@@ -255,6 +274,8 @@ template <typename T, int D, int V = D>
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
} else if (float_mask) {
use_key = (fmask[0] >= Limits<T>::finite_min);
}
if (use_key) {
// Read the key
@@ -268,6 +289,10 @@ template <typename T, int D, int V = D>
score += q[i] * k[i];
}
score = simd_sum(score);
if (score < Limits<T>::finite_min) {
continue;
}
if (float_mask) {
score += fmask[0];
}
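
A note on the continue added above: it drops any key whose reduced score has already saturated below Limits<T>::finite_min. Skipping leaves (max_score, sum_exp_score) untouched, which matches the guarded update's factor == 1, exp_score == 0 branch, while also avoiding the value accumulation for that key.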

View File

@@ -11,6 +11,7 @@ constant bool align_K [[function_constant(201)]];
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];
template <typename T>
struct TransformScale {
@@ -82,6 +83,7 @@ template <
const constant AttnParams* params [[buffer(4)]],
const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
const device T* sinks [[buffer(7), function_constant(has_sinks)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -169,7 +171,7 @@ template <
VBlockLoader loader_v(
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
TransformScale<T> ts(static_cast<T>(params->scale * M_LOG2E_F));
// Prepare MMA tiles
constexpr short kFragSize = 8; // MMAFrag size
@@ -232,6 +234,14 @@ template <
max_score[i] = Limits<AccumType>::finite_min;
}
if (has_sinks) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
sum_score[i] = 1;
}
}
int kb_lim = params->NK;
if (do_causal) {
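
As in the CUDA kernels, the sink is pre-multiplied by M_LOG2E_F here because this kernel's TransformScale folds log2(e) into the scores (see the ts change above), keeping the whole accumulation in base 2. Every row owned by the thread is seeded with the same value since the sink is per head; sinks[tidl.y] works because the launch grid is (NQ, H, B), making the y coordinate the head index.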
@@ -350,7 +360,7 @@ template <
Stile.frag_at(i, j)[jj] =
mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
} else {
Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]);
}
}
}

View File

@@ -21,8 +21,9 @@ void sdpa_full_self_attention_metal(
const array& v,
const float scale,
array& o,
bool do_causal_ = false,
const std::optional<array>& mask = std::nullopt) {
bool do_causal_,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
using namespace mlx::steel;
int wm = 4;
@@ -42,35 +43,49 @@ void sdpa_full_self_attention_metal(
const bool align_Q = (qL % bq) == 0;
const bool align_K = (kL % bk) == 0;
const bool has_mask = !!mask;
const bool has_mask = mask.has_value();
const bool do_causal = do_causal_;
const bool has_sinks = sinks.has_value();
metal::MTLFCList func_consts = {
{&align_Q, MTL::DataType::DataTypeBool, 200},
{&align_K, MTL::DataType::DataTypeBool, 201},
{&has_mask, MTL::DataType::DataTypeBool, 300},
{&do_causal, MTL::DataType::DataTypeBool, 301}};
{&do_causal, MTL::DataType::DataTypeBool, 301},
{&has_sinks, MTL::DataType::DataTypeBool, 302}};
std::ostringstream kname;
// clang-format off
kname << "steel_attention_"
<< type_to_name(q)
<< "_bq" << bq
<< "_bk" << bk
<< "_bd" << bd
<< "_wm" << wm
<< "_wn" << wn
<< "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on
std::string base_name;
concatenate(
base_name,
"steel_attention_",
type_to_name(q),
"_bq",
bq,
"_bk",
bk,
"_bd",
bd,
"_wm",
wm,
"_wn",
wn,
"_mask",
type_to_name(has_mask ? *mask : q));
std::string base_name = kname.str();
// clang-format off
kname << "_align_Q_" << (align_Q ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_has_mask_" << (has_mask ? 't' : 'n')
<< "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
std::string hash_name;
concatenate(
hash_name,
base_name,
"_align_Q_",
(align_Q ? 't' : 'n'),
"_align_K_",
(align_K ? 't' : 'n'),
"_has_mask_",
(has_mask ? 't' : 'n'),
"_do_causal_",
(do_causal ? 't' : 'n'),
"_has_sinks_",
(has_sinks ? 't' : 'n'));
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
@@ -114,8 +129,8 @@ void sdpa_full_self_attention_metal(
compute_encoder.set_output_array(o, 3);
compute_encoder.set_bytes(params, 4);
if (mask) {
auto m = *mask;
if (has_mask) {
auto& m = *mask;
AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
m.strides(0), m.strides(1), m.strides(2)}};
@@ -123,6 +138,9 @@ void sdpa_full_self_attention_metal(
compute_encoder.set_bytes(mask_params, 5);
compute_encoder.set_input_array(m, 6);
}
if (has_sinks) {
compute_encoder.set_input_array(*sinks, 7);
}
MTL::Size grid_dims = MTL::Size(NQ, H, B);
MTL::Size group_dims = MTL::Size(32, wm, wn);
@@ -139,7 +157,8 @@ void sdpa_vector(
array& out,
float scale,
bool do_causal,
const std::optional<array>& mask) {
const std::optional<array>& mask,
const std::optional<array>& sinks) {
// Set the kernel name
std::string kname;
kname.reserve(64);
@@ -153,30 +172,32 @@ void sdpa_vector(
// Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1);
int N = k.shape(2);
int B = q.shape(0) * q.shape(1);
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
size_t k_seq_stride = k.strides()[2];
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
size_t v_seq_stride = v.strides()[2];
MTL::Size group_dims(1024, 1, 1);
MTL::Size grid_dims(B, q.shape(2), 1);
MTL::Size grid_dims(q.shape(0) * q.shape(1), q.shape(2), 1);
bool has_mask = mask.has_value();
bool bool_mask = has_mask && (*mask).dtype() == bool_;
bool float_mask = has_mask && !bool_mask;
bool query_transposed = !q.flags().row_contiguous;
bool has_sinks = sinks.has_value();
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
{&bool_mask, MTL::DataType::DataTypeBool, 23},
{&float_mask, MTL::DataType::DataTypeBool, 24},
{&has_sinks, MTL::DataType::DataTypeBool, 25},
};
std::string hash_name = kname;
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
hash_name += has_sinks ? "_sinks" : "_nosinks";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -207,6 +228,10 @@ void sdpa_vector(
compute_encoder.set_bytes(q_seq_stride, 14);
compute_encoder.set_bytes(head_stride, 15);
}
if (has_sinks) {
compute_encoder.set_input_array(*sinks, 16);
compute_encoder.set_bytes(q.shape(1), 17);
}
// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
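
The two extra bindings above are only made when the has_sinks function constant is set, matching the [[buffer(16/17), function_constant(has_sinks)]] declarations in the kernel. Passing q.shape(1) lets the kernel unpack the head from the flattened grid index, roughly:

// Grid is (B * H, qL, 1) (see grid_dims above), so tid.x packs batch and
// head together; the kernel recovers the per-head sink with:
//   int head = q_batch_head_idx % num_q_heads; // num_q_heads == q.shape(1)
//   U sink = static_cast<U>(sinks[head]);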
@@ -221,7 +246,8 @@ void sdpa_vector_2pass(
array& out,
float scale,
bool do_causal,
const std::optional<array>& mask) {
const std::optional<array>& mask,
const std::optional<array>& sinks) {
// Set the kernel name
std::string kname;
kname.reserve(64);
@@ -267,17 +293,20 @@ void sdpa_vector_2pass(
bool bool_mask = has_mask && (*mask).dtype() == bool_;
bool float_mask = has_mask && !bool_mask;
bool query_transposed = !q.flags().row_contiguous;
bool has_sinks = sinks.has_value();
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
{&bool_mask, MTL::DataType::DataTypeBool, 23},
{&float_mask, MTL::DataType::DataTypeBool, 24},
{&has_sinks, MTL::DataType::DataTypeBool, 25},
};
std::string hash_name = kname;
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
hash_name += has_sinks ? "_sinks" : "_nosinks";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -310,6 +339,10 @@ void sdpa_vector_2pass(
compute_encoder.set_bytes(q_seq_stride, 16);
compute_encoder.set_bytes(head_stride, 17);
}
if (has_sinks) {
compute_encoder.set_input_array(*sinks, 18);
compute_encoder.set_bytes(q.shape(1), 19);
}
// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -411,6 +444,12 @@ void ScaledDotProductAttention::eval_gpu(
return arr.strides(-1) == 1;
};
std::optional<array> sinks = std::nullopt;
if (has_sinks_) {
sinks = copy_unless(is_matrix_contiguous, inputs.back());
}
bool has_arr_mask = inputs.size() > (3 + has_sinks_);
// We are in vector mode ie single query
if (q_pre.shape(2) <= 8) {
auto q_copy_unless = [](const array& arr) {
@@ -462,7 +501,7 @@ void ScaledDotProductAttention::eval_gpu(
(strides[0] == strides[1] * shape[1]);
};
auto mask = inputs.size() > 3
auto mask = has_arr_mask
? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}
: std::nullopt;
@@ -473,9 +512,9 @@ void ScaledDotProductAttention::eval_gpu(
char devc = d.get_architecture().back();
if ((devc == 'd' && k.shape(2) >= 1024) ||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask);
sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
} else {
sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask);
sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
}
}
@@ -503,11 +542,12 @@ void ScaledDotProductAttention::eval_gpu(
{str_oB, str_oH, str_oL, str_oD},
flags);
auto mask = inputs.size() > 3
auto mask = has_arr_mask
? std::optional<array>{copy_unless(is_matrix_contiguous, inputs[3])}
: std::nullopt;
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask);
sdpa_full_self_attention_metal(
s, d, q, k, v, scale_, o, do_causal_, mask, sinks);
}
d.add_temporaries(std::move(copies), s.index);

View File

@@ -579,6 +579,7 @@ array scaled_dot_product_attention(
const float scale,
const std::string& mask_mode /* = "" */,
const std::vector<array>& mask_arrs /* = {} */,
const std::optional<array>& sinks /* = {} */,
StreamOrDevice s /* = {}*/) {
for (const auto& tensor : {queries, keys, values}) {
if (tensor.ndim() != 4) {
@@ -679,13 +680,20 @@ array scaled_dot_product_attention(
<< final_type << ".";
throw std::invalid_argument(msg.str());
}
bool has_sinks = sinks.has_value();
auto q = astype(queries, final_type, s);
auto k = astype(keys, final_type, s);
auto v = astype(values, final_type, s);
auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s](
const std::vector<array>& inputs) {
auto fallback = [scale,
final_type,
n_q_heads,
n_kv_heads,
do_causal,
has_sinks,
has_arr_mask,
s](const std::vector<array>& inputs) {
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
int n_repeats = n_q_heads / n_kv_heads;
int B = q.shape(0);
@@ -698,20 +706,22 @@ array scaled_dot_product_attention(
v = expand_dims(v, 2, s);
}
auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
if (inputs.size() > 3 || do_causal) {
if (has_arr_mask || do_causal) {
// Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
auto mask = inputs.back();
if (do_causal) {
int kL = k.shape(-2);
int qL = q.shape(-2);
int q_off = (kL - qL) < 0 ? 0 : (kL - qL);
auto q_idx = arange(q_off, q_off + qL, s);
auto k_idx = arange(0, kL, s);
q_idx = expand_dims(q_idx, 1, s);
k_idx = expand_dims(k_idx, 0, s);
mask = greater_equal(q_idx, k_idx, s);
}
auto make_or_fetch_mask = [&]() {
if (do_causal) {
int kL = k.shape(-2);
int qL = q.shape(-2);
int q_off = (kL - qL) < 0 ? 0 : (kL - qL);
auto q_idx = arange(q_off, q_off + qL, s);
auto k_idx = arange(0, kL, s);
q_idx = expand_dims(q_idx, 1, s);
k_idx = expand_dims(k_idx, 0, s);
return greater_equal(q_idx, k_idx, s);
}
return inputs[3];
};
auto mask = make_or_fetch_mask();
if (n_repeats > 1 && mask.ndim() >= 3) {
if (mask.shape(-3) == 1) {
@@ -730,7 +740,25 @@ array scaled_dot_product_attention(
scores = add(scores, mask, s);
}
}
if (has_sinks) {
auto sinks = inputs.back();
// scores has shape B N_q N_k L_q L_k
sinks = expand_dims(sinks, {0, 2, 3}, s);
if (scores.ndim() == 5) {
sinks = unflatten(sinks, 1, {n_kv_heads, n_repeats}, s);
}
auto bsx_shape = scores.shape();
bsx_shape.back() = 1;
scores = concatenate({broadcast_to(sinks, bsx_shape, s), scores}, -1, s);
}
scores = softmax(scores, std::vector<int>{-1}, true, s);
if (has_sinks) {
// Slice off scores
auto start = Shape(scores.ndim(), 0);
start.back() = 1;
auto stop = scores.shape();
scores = slice(scores, std::move(start), std::move(stop), s);
}
auto out = matmul(scores, v, s);
if (n_repeats > 1) {
out = flatten(out, 1, 2, s);
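
A toy example of the concatenate-softmax-slice trick above, with one head, scores [2, 1], and sink 0: the padded row softmaxes to [e^0, e^2, e^1] / (e^0 + e^2 + e^1) ≈ [0.090, 0.665, 0.245], and slicing off the first column leaves weights [0.665, 0.245], which deliberately sum to less than one because the sink kept the remaining mass.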
@@ -746,7 +774,7 @@ array scaled_dot_product_attention(
has_bool_mask = mask_arr.dtype() == bool_;
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
msg << "[scaled_dot_product_attention] Mask type must promote to output type "
<< final_type << ".";
throw std::invalid_argument(msg.str());
} else if (!has_bool_mask) {
@@ -757,6 +785,22 @@ array scaled_dot_product_attention(
mask_shape.back() = keys.shape(-2);
inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
}
if (has_sinks) {
if (promote_types(sinks->dtype(), final_type) != final_type) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Type of sinks must promote to output type "
<< final_type << ".";
throw std::invalid_argument(msg.str());
}
if (sinks->ndim() != 1 || sinks->shape(0) != n_q_heads) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Received invalid shape for sinks "
<< sinks->shape() << ".";
throw std::invalid_argument(msg.str());
}
inputs.push_back(astype(*sinks, final_type, stream));
}
if (!ScaledDotProductAttention::use_fallback(
q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
@@ -764,7 +808,7 @@ array scaled_dot_product_attention(
std::move(out_shape),
final_type,
std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, do_causal),
stream, fallback, scale, do_causal, has_sinks),
std::move(inputs));
}
return fallback(std::move(inputs))[0];
@@ -773,7 +817,8 @@ array scaled_dot_product_attention(
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
const ScaledDotProductAttention& a_other =
static_cast<const ScaledDotProductAttention&>(other);
return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ &&
has_sinks_ == a_other.has_sinks_;
}
bool Quantize::is_equivalent(const Primitive& other) const {

View File

@@ -50,6 +50,7 @@ array scaled_dot_product_attention(
const float scale,
const std::string& mask_mode = "",
const std::vector<array>& mask_arrs = {},
const std::optional<array>& sinks = {},
StreamOrDevice s = {});
using TemplateArg = std::variant<int, bool, Dtype>;

View File

@@ -208,9 +208,13 @@ class ScaledDotProductAttention : public Custom {
explicit ScaledDotProductAttention(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
const float scale,
const bool do_causal)
: Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {}
float scale,
bool do_causal,
bool has_sinks)
: Custom(stream, fallback),
scale_(scale),
do_causal_(do_causal),
has_sinks_(has_sinks) {}
static bool use_fallback(
const array& q,
@@ -237,12 +241,13 @@ class ScaledDotProductAttention : public Custom {
DEFINE_NAME(ScaledDotProductAttention);
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return std::make_tuple(nullptr, scale_, do_causal_);
return std::make_tuple(nullptr, scale_, do_causal_, has_sinks_);
}
private:
float scale_;
bool do_causal_;
bool has_sinks_;
};
class Quantize : public Custom {