mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
add sdpa with sinks
mlx/fast.cpp (81 lines changed)
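In short: `scaled_dot_product_attention` gains an optional `sinks` argument carrying one logit per query head. In the fallback path these logits are concatenated as an extra column of the attention scores before the softmax and sliced off afterwards, and the `ScaledDotProductAttention` primitive is constructed with a `has_sinks` flag so it can be told apart from the sink-free variant.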
@@ -579,6 +579,7 @@ array scaled_dot_product_attention(
     const float scale,
     const std::string& mask_mode /* = "" */,
     const std::vector<array>& mask_arrs /* = {} */,
+    const std::optional<array>& sinks /* = {} */,
     StreamOrDevice s /* = {}*/) {
   for (const auto& tensor : {queries, keys, values}) {
     if (tensor.ndim() != 4) {
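Not part of the commit: a minimal C++ sketch of how the widened signature above might be called, assuming the usual `mlx/mlx.h` entry points and the existing "causal" mask mode; shapes and values are illustrative only.

#include <cmath>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // Illustrative shapes: B=1, 8 query heads, 32 tokens, head_dim=64.
  auto q = mx::random::normal({1, 8, 32, 64});
  auto k = mx::random::normal({1, 8, 32, 64});
  auto v = mx::random::normal({1, 8, 32, 64});
  // One sink logit per query head, matching the shape check added below.
  auto sinks = mx::zeros({8});

  auto out = mx::fast::scaled_dot_product_attention(
      q,
      k,
      v,
      /* scale */ 1.0f / std::sqrt(64.0f),
      /* mask_mode */ "causal",
      /* mask_arrs */ {},
      /* sinks */ sinks);
  mx::eval(out); // out has shape {1, 8, 32, 64}
  return 0;
}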
@@ -679,13 +680,20 @@ array scaled_dot_product_attention(
         << final_type << ".";
     throw std::invalid_argument(msg.str());
   }
+  bool has_sinks = sinks.has_value();

   auto q = astype(queries, final_type, s);
   auto k = astype(keys, final_type, s);
   auto v = astype(values, final_type, s);

-  auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s](
-                      const std::vector<array>& inputs) {
+  auto fallback = [scale,
+                   final_type,
+                   n_q_heads,
+                   n_kv_heads,
+                   do_causal,
+                   has_sinks,
+                   has_arr_mask,
+                   s](const std::vector<array>& inputs) {
     auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
     int n_repeats = n_q_heads / n_kv_heads;
     int B = q.shape(0);
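A note on the widened capture list: with sinks in play the fallback can no longer infer the presence of a mask from `inputs.size()`, so `has_arr_mask` and `has_sinks` are captured explicitly. The inputs vector is laid out as q, k, v, then the broadcast mask (if any), then the sink logits last, which is why the mask is read from `inputs[3]` and the sinks from `inputs.back()` further down.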
@@ -698,20 +706,22 @@ array scaled_dot_product_attention(
       v = expand_dims(v, 2, s);
     }
     auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
-    if (inputs.size() > 3 || do_causal) {
+    if (has_arr_mask || do_causal) {
       // Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
-      auto mask = inputs.back();
-
-      if (do_causal) {
-        int kL = k.shape(-2);
-        int qL = q.shape(-2);
-        int q_off = (kL - qL) < 0 ? 0 : (kL - qL);
-        auto q_idx = arange(q_off, q_off + qL, s);
-        auto k_idx = arange(0, kL, s);
-        q_idx = expand_dims(q_idx, 1, s);
-        k_idx = expand_dims(k_idx, 0, s);
-        mask = greater_equal(q_idx, k_idx, s);
-      }
+      auto make_or_fetch_mask = [&]() {
+        if (do_causal) {
+          int kL = k.shape(-2);
+          int qL = q.shape(-2);
+          int q_off = (kL - qL) < 0 ? 0 : (kL - qL);
+          auto q_idx = arange(q_off, q_off + qL, s);
+          auto k_idx = arange(0, kL, s);
+          q_idx = expand_dims(q_idx, 1, s);
+          k_idx = expand_dims(k_idx, 0, s);
+          return greater_equal(q_idx, k_idx, s);
+        }
+        return inputs[3];
+      };
+      auto mask = make_or_fetch_mask();

       if (n_repeats > 1 && mask.ndim() >= 3) {
         if (mask.shape(-3) == 1) {
@@ -730,7 +740,25 @@ array scaled_dot_product_attention(
         scores = add(scores, mask, s);
       }
     }
+    if (has_sinks) {
+      auto sinks = inputs.back();
+      // scores has shape B N_q N_k L_q L_k
+      sinks = expand_dims(sinks, {0, 2, 3}, s);
+      if (scores.ndim() == 5) {
+        sinks = unflatten(sinks, 1, {n_kv_heads, n_repeats}, s);
+      }
+      auto bsx_shape = scores.shape();
+      bsx_shape.back() = 1;
+      scores = concatenate({broadcast_to(sinks, bsx_shape, s), scores}, -1, s);
+    }
     scores = softmax(scores, std::vector<int>{-1}, true, s);
+    if (has_sinks) {
+      // Slice off scores
+      auto start = Shape(scores.ndim(), 0);
+      start.back() = 1;
+      auto stop = scores.shape();
+      scores = slice(scores, std::move(start), std::move(stop), s);
+    }
     auto out = matmul(scores, v, s);
     if (n_repeats > 1) {
       out = flatten(out, 1, 2, s);
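The concatenate/softmax/slice sequence above is the entire sink mechanism: the sink logits take part only in the softmax normalization, so every attention weight shrinks and the rows no longer sum to one. Not part of the commit, a self-contained plain-C++ sketch of that arithmetic for a single score row (values made up):

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<float> scores = {2.0f, 1.0f, 0.5f}; // one query row, 3 keys
  float sink = 1.5f;                              // per-head sink logit

  // Fallback path: prepend the sink as an extra column, softmax, slice it off.
  std::vector<float> widened = {sink};
  widened.insert(widened.end(), scores.begin(), scores.end());
  float denom = 0.0f;
  for (float x : widened) {
    denom += std::exp(x);
  }
  std::cout << "with sink column:   ";
  for (std::size_t i = 1; i < widened.size(); ++i) {
    std::cout << std::exp(widened[i]) / denom << " ";
  }
  std::cout << "\n";

  // Equivalent view: an ordinary softmax whose denominator gains exp(sink).
  float base = 0.0f;
  for (float x : scores) {
    base += std::exp(x);
  }
  std::cout << "exp(sink) in denom: ";
  for (float x : scores) {
    std::cout << std::exp(x) / (base + std::exp(sink)) << " ";
  }
  std::cout << "\n";
  return 0;
}

Both loops print the same three weights, which is why slicing off the first column after the softmax is all the cleanup the fallback needs.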
@@ -746,7 +774,7 @@ array scaled_dot_product_attention(
     has_bool_mask = mask_arr.dtype() == bool_;
     if (promote_types(mask_arr.dtype(), final_type) != final_type) {
       std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
+      msg << "[scaled_dot_product_attention] Mask type must promote to output type "
           << final_type << ".";
       throw std::invalid_argument(msg.str());
     } else if (!has_bool_mask) {
@@ -757,6 +785,22 @@ array scaled_dot_product_attention(
     mask_shape.back() = keys.shape(-2);
     inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
   }
+  if (has_sinks) {
+    if (promote_types(sinks->dtype(), final_type) != final_type) {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] Type of sinks must promote to output type "
+          << final_type << ".";
+      throw std::invalid_argument(msg.str());
+    }
+    if (sinks->ndim() != 1 || sinks->shape(0) != n_q_heads) {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] Received invalid shape for sinks "
+          << sinks->shape() << ".";
+      throw std::invalid_argument(msg.str());
+    }
+    inputs.push_back(astype(*sinks, final_type, stream));
+  }
+
   if (!ScaledDotProductAttention::use_fallback(
           q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
@@ -764,7 +808,7 @@ array scaled_dot_product_attention(
         std::move(out_shape),
         final_type,
         std::make_shared<ScaledDotProductAttention>(
-            stream, fallback, scale, do_causal),
+            stream, fallback, scale, do_causal, has_sinks),
         std::move(inputs));
   }
   return fallback(std::move(inputs))[0];
@@ -773,7 +817,8 @@ array scaled_dot_product_attention(
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
   const ScaledDotProductAttention& a_other =
       static_cast<const ScaledDotProductAttention&>(other);
-  return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
+  return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ &&
+      has_sinks_ == a_other.has_sinks_;
 }

 bool Quantize::is_equivalent(const Primitive& other) const {
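Including `has_sinks_` in `is_equivalent` keeps a sink-carrying attention primitive from being treated as interchangeable with a sink-free one when primitives are compared.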