add sdpa with sinks

Awni Hannun
2025-08-31 10:59:50 -07:00
parent dde3682b69
commit 3ca3ab9dcd
9 changed files with 298 additions and 96 deletions


@@ -579,6 +579,7 @@ array scaled_dot_product_attention(
const float scale,
const std::string& mask_mode /* = "" */,
const std::vector<array>& mask_arrs /* = {} */,
+ const std::optional<array>& sinks /* = {} */,
StreamOrDevice s /* = {}*/) {
for (const auto& tensor : {queries, keys, values}) {
if (tensor.ndim() != 4) {
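
For orientation, here is a hedged sketch of how the extended entry point might be called once this lands. The umbrella header, the namespace alias, and the random/zeros helpers are assumptions based on the rest of MLX and are not shown in this diff; sinks is a 1-D array with one logit per query head, as enforced by the validation added further down in the commit.

    #include <cmath>
    #include "mlx/mlx.h"  // assumed umbrella header; fast.h provides the declaration

    namespace mx = mlx::core;

    int main() {
      int B = 2, n_q_heads = 8, n_kv_heads = 2, L = 16, D = 64;
      auto q = mx::random::normal({B, n_q_heads, L, D});
      auto k = mx::random::normal({B, n_kv_heads, L, D});
      auto v = mx::random::normal({B, n_kv_heads, L, D});

      // One sink logit per query head (1-D, length n_q_heads); zeros just for illustration.
      auto sinks = mx::zeros({n_q_heads});

      float scale = 1.0f / std::sqrt(static_cast<float>(D));
      auto out = mx::fast::scaled_dot_product_attention(
          q, k, v, scale, /* mask_mode */ "", /* mask_arrs */ {}, sinks);
      // out has shape [B, n_q_heads, L, D]
    }
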
@@ -679,13 +680,20 @@ array scaled_dot_product_attention(
<< final_type << ".";
throw std::invalid_argument(msg.str());
}
+ bool has_sinks = sinks.has_value();
auto q = astype(queries, final_type, s);
auto k = astype(keys, final_type, s);
auto v = astype(values, final_type, s);
- auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s](
- const std::vector<array>& inputs) {
+ auto fallback = [scale,
+ final_type,
+ n_q_heads,
+ n_kv_heads,
+ do_causal,
+ has_sinks,
+ has_arr_mask,
+ s](const std::vector<array>& inputs) {
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
int n_repeats = n_q_heads / n_kv_heads;
int B = q.shape(0);
@@ -698,20 +706,22 @@ array scaled_dot_product_attention(
v = expand_dims(v, 2, s);
}
auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
- if (inputs.size() > 3 || do_causal) {
+ if (has_arr_mask || do_causal) {
// Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
- auto mask = inputs.back();
- if (do_causal) {
- int kL = k.shape(-2);
- int qL = q.shape(-2);
- int q_off = (kL - qL) < 0 ? 0 : (kL - qL);
- auto q_idx = arange(q_off, q_off + qL, s);
- auto k_idx = arange(0, kL, s);
- q_idx = expand_dims(q_idx, 1, s);
- k_idx = expand_dims(k_idx, 0, s);
- mask = greater_equal(q_idx, k_idx, s);
- }
+ auto make_or_fetch_mask = [&]() {
+ if (do_causal) {
+ int kL = k.shape(-2);
+ int qL = q.shape(-2);
+ int q_off = (kL - qL) < 0 ? 0 : (kL - qL);
+ auto q_idx = arange(q_off, q_off + qL, s);
+ auto k_idx = arange(0, kL, s);
+ q_idx = expand_dims(q_idx, 1, s);
+ k_idx = expand_dims(k_idx, 0, s);
+ return greater_equal(q_idx, k_idx, s);
+ }
+ return inputs[3];
+ };
+ auto mask = make_or_fetch_mask();
if (n_repeats > 1 && mask.ndim() >= 3) {
if (mask.shape(-3) == 1) {
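
The refactor wraps the old inline mask code in a make_or_fetch_mask helper: with do_causal it builds a boolean mask from index comparisons, offsetting the query positions by kL - qL so that queries line up with the last qL key positions (as when new tokens are appended to a longer key/value cache); otherwise it fetches the user-provided mask from inputs[3]. A plain-C++ sketch of the same index rule, independent of MLX:

    #include <cstdio>
    #include <vector>

    // Boolean causal mask for qL queries against kL keys, matching the
    // greater_equal(q_idx, k_idx) construction in the diff above.
    std::vector<std::vector<bool>> causal_mask(int qL, int kL) {
      int q_off = (kL - qL) < 0 ? 0 : (kL - qL); // queries sit at the end of the keys
      std::vector<std::vector<bool>> mask(qL, std::vector<bool>(kL));
      for (int i = 0; i < qL; ++i) {
        for (int j = 0; j < kL; ++j) {
          mask[i][j] = (q_off + i) >= j; // true -> this key position is visible
        }
      }
      return mask;
    }

    int main() {
      auto m = causal_mask(/*qL=*/2, /*kL=*/4); // e.g. 2 new tokens with 2 cached keys
      for (auto& row : m) {
        for (bool keep : row) std::printf("%d ", keep ? 1 : 0);
        std::printf("\n");
      }
      // Prints:
      // 1 1 1 0
      // 1 1 1 1
    }
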
@@ -730,7 +740,25 @@ array scaled_dot_product_attention(
scores = add(scores, mask, s);
}
}
+ if (has_sinks) {
+ auto sinks = inputs.back();
+ // scores has shape B N_q N_k L_q L_k
+ sinks = expand_dims(sinks, {0, 2, 3}, s);
+ if (scores.ndim() == 5) {
+ sinks = unflatten(sinks, 1, {n_kv_heads, n_repeats}, s);
+ }
+ auto bsx_shape = scores.shape();
+ bsx_shape.back() = 1;
+ scores = concatenate({broadcast_to(sinks, bsx_shape, s), scores}, -1, s);
+ }
scores = softmax(scores, std::vector<int>{-1}, true, s);
+ if (has_sinks) {
+ // Slice off scores
+ auto start = Shape(scores.ndim(), 0);
+ start.back() = 1;
+ auto stop = scores.shape();
+ scores = slice(scores, std::move(start), std::move(stop), s);
+ }
auto out = matmul(scores, v, s);
if (n_repeats > 1) {
out = flatten(out, 1, 2, s);
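
The sink logic amounts to one extra softmax column: each head's sink logit is broadcast to a [..., L_q, 1] column, concatenated in front of the scores, softmaxed together with them, and then sliced back off. The retained attention weights therefore sum to at most one, with the remainder absorbed by the sink. A minimal plain-C++ sketch of that arithmetic for a single query row (not MLX code):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Softmax over {sink_logit, scores...}, then drop the sink column,
    // mirroring the concatenate -> softmax -> slice sequence in the diff.
    std::vector<float> softmax_with_sink(std::vector<float> scores, float sink_logit) {
      std::vector<float> logits;
      logits.push_back(sink_logit); // sink column goes first, as in the kernel
      logits.insert(logits.end(), scores.begin(), scores.end());

      float max_l = logits[0];
      for (float l : logits) max_l = std::max(max_l, l);

      float denom = 0.0f;
      for (float& l : logits) {
        l = std::exp(l - max_l);
        denom += l;
      }

      // Keep only the weights for real keys; they sum to <= 1 because the
      // sink absorbed part of the probability mass.
      std::vector<float> weights(logits.begin() + 1, logits.end());
      for (float& w : weights) w /= denom;
      return weights;
    }

    int main() {
      auto w = softmax_with_sink({1.0f, 2.0f, 3.0f}, /*sink_logit=*/4.0f);
      float sum = 0.0f;
      for (float x : w) {
        std::printf("%.3f ", x);
        sum += x;
      }
      std::printf("\nsum = %.3f (the rest went to the sink)\n", sum);
    }
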
@@ -746,7 +774,7 @@ array scaled_dot_product_attention(
has_bool_mask = mask_arr.dtype() == bool_;
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
msg << "[scaled_dot_product_attention] Mask type must promote to output type "
<< final_type << ".";
throw std::invalid_argument(msg.str());
} else if (!has_bool_mask) {
@@ -757,6 +785,22 @@ array scaled_dot_product_attention(
mask_shape.back() = keys.shape(-2);
inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
}
+ if (has_sinks) {
+ if (promote_types(sinks->dtype(), final_type) != final_type) {
+ std::ostringstream msg;
+ msg << "[scaled_dot_product_attention] Type of sinks must promote to output type "
+ << final_type << ".";
+ throw std::invalid_argument(msg.str());
+ }
+ if (sinks->ndim() != 1 || sinks->shape(0) != n_q_heads) {
+ std::ostringstream msg;
+ msg << "[scaled_dot_product_attention] Received invalid shape for sinks "
+ << sinks->shape() << ".";
+ throw std::invalid_argument(msg.str());
+ }
+ inputs.push_back(astype(*sinks, final_type, stream));
+ }
if (!ScaledDotProductAttention::use_fallback(
q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
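
The validation pins down the contract for sinks: a 1-D array with exactly n_q_heads entries whose dtype promotes to the attention's compute type. A few hypothetical examples, assuming eight query heads and float16 queries/keys/values (so the output type is float16):

    #include "mlx/mlx.h"  // assumed umbrella header

    namespace mx = mlx::core;

    int main() {
      // Assuming n_q_heads == 8 and float16 q/k/v, i.e. final_type == float16.
      auto ok = mx::zeros({8}, mx::float16);          // accepted: 1-D, one entry per query head
      auto bad_rank = mx::zeros({1, 8}, mx::float16); // would throw: sinks must be 1-D
      auto bad_len = mx::zeros({4}, mx::float16);     // would throw: length must equal n_q_heads
      auto bad_type = mx::zeros({8}, mx::float32);    // would throw: float32 does not promote
                                                      // down to the float16 output type
    }
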
@@ -764,7 +808,7 @@ array scaled_dot_product_attention(
std::move(out_shape),
final_type,
std::make_shared<ScaledDotProductAttention>(
- stream, fallback, scale, do_causal),
+ stream, fallback, scale, do_causal, has_sinks),
std::move(inputs));
}
return fallback(std::move(inputs))[0];
@@ -773,7 +817,8 @@ array scaled_dot_product_attention(
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
const ScaledDotProductAttention& a_other =
static_cast<const ScaledDotProductAttention&>(other);
- return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
+ return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ &&
+ has_sinks_ == a_other.has_sinks_;
}
bool Quantize::is_equivalent(const Primitive& other) const {
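
The matching header change is not shown in this diff. Based on the constructor call (stream, fallback, scale, do_causal, has_sinks) and the is_equivalent update above, the primitive presumably carries a has_sinks_ flag roughly like the sketch below; the base class, the fallback signature, and the member layout are assumptions, not the actual MLX declaration:

    // Hypothetical sketch of the (unshown) primitive declaration.
    class ScaledDotProductAttention : public Custom /* assumed base */ {
     public:
      ScaledDotProductAttention(
          Stream stream,
          std::function<std::vector<array>(const std::vector<array>&)> fallback,
          const float scale,
          const bool do_causal,
          const bool has_sinks)
          : Custom(stream, std::move(fallback)),
            scale_(scale),
            do_causal_(do_causal),
            has_sinks_(has_sinks) {}

      bool is_equivalent(const Primitive& other) const override;

     private:
      float scale_;
      bool do_causal_;
      bool has_sinks_;
    };
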