fix per-example mask + docs in sdpa (#1574)

Awni Hannun
2024-11-08 11:51:15 -08:00
committed by GitHub
parent 9f0d5c12fc
commit 91c0277356
3 changed files with 42 additions and 11 deletions

@@ -533,6 +533,12 @@ array scaled_dot_product_attention(
       throw std::invalid_argument(msg.str());
     }
   }
+  if (mask and (*mask).ndim() > 4) {
+    std::ostringstream msg;
+    msg << "[scaled_dot_product_attention] the mask with shape "
+        << (*mask).shape() << " expected to have at most rank 4";
+    throw std::invalid_argument(msg.str());
+  }
 
   const size_t batch_dim = queries.shape(0);
   for (const auto& tensor : {keys, values}) {
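
The first hunk adds an up-front rank check: an additive mask must broadcast into the 4-D attention shape (batch, heads, query length, key length), so anything above rank 4 can never be valid and is rejected before dispatch. A minimal standalone sketch of that rule, with assumed sizes and a hypothetical check_mask_rank helper (not part of the MLX API):

#include <stdexcept>
#include <vector>

// Hypothetical standalone version of the new check: an additive mask must
// broadcast into (B, n_heads, L_q, L_k), so rank > 4 is rejected up front.
void check_mask_rank(const std::vector<int>& mask_shape) {
  if (mask_shape.size() > 4) {
    throw std::invalid_argument(
        "[scaled_dot_product_attention] mask expected to have at most rank 4");
  }
}

int main() {
  check_mask_rank({16, 16});        // (L_q, L_k): shared across the batch, fine
  check_mask_rank({2, 1, 16, 16});  // (B, 1, L_q, L_k): per-example mask, fine
  // check_mask_rank({2, 1, 1, 16, 16});  // rank 5: would throw
  return 0;
}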
@@ -599,8 +605,7 @@ array scaled_dot_product_attention(
     threshold = std::max(1, memory_efficient_threshold.value());
   }
-  bool needs_mask = mask.has_value();
-  auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
+  auto fallback = [scale, final_type, n_q_heads, n_kv_heads, &s](
       const std::vector<array>& inputs) {
     auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
     int n_repeats = n_q_heads / n_kv_heads;
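
The second hunk drops the captured needs_mask flag: whether a mask was passed is now recoverable from the fallback's own argument list (inputs.size() > 3), so the lambda carries no state that could disagree with the inputs it is actually given. A small sketch of that optional-trailing-input pattern, using a plain std::vector<float> as a stand-in for mlx::core::array:

#include <cstdio>
#include <vector>

using Array = std::vector<float>;  // stand-in for mlx::core::array

// inputs[0..2] are q, k, v; inputs[3], if present, is the additive mask.
// Presence is encoded in the input list itself, not in a captured boolean.
void fallback(const std::vector<Array>& inputs) {
  if (inputs.size() > 3) {
    std::printf("mask provided with %zu elements\n", inputs[3].size());
  } else {
    std::printf("no mask\n");
  }
}

int main() {
  fallback({{1.f}, {1.f}, {1.f}});         // q, k, v only
  fallback({{1.f}, {1.f}, {1.f}, {0.f}});  // q, k, v, mask
  return 0;
}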
@@ -614,8 +619,12 @@ array scaled_dot_product_attention(
       v = expand_dims(v, 2, s);
     }
     auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
-    if (needs_mask) {
-      scores = add(scores, inputs[3], s);
+    if (inputs.size() > 3) {
+      auto mask_shape = inputs[0].shape();
+      mask_shape.back() = inputs[1].shape(-2);
+      auto mask = reshape(
+          broadcast_to(inputs[3], std::move(mask_shape), s), scores.shape(), s);
+      scores = add(scores, mask, s);
     }
     scores = softmax(scores, std::vector<int>{-1}, true, s);
     auto out = matmul(scores, v, s);
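
The third hunk is the per-example fix itself. With grouped-query attention (n_repeats > 1), q gains an extra axis, so scores is 5-D: (B, n_kv_heads, n_repeats, L_q, L_k). The old code added the user's mask directly, and a 4-D per-example mask such as (B, 1, L_q, L_k) would then align its batch axis against n_kv_heads rather than the batch (or fail to broadcast at all). The fix first broadcasts the mask against the query shape with its last axis swapped to the key sequence length, giving (B, n_q_heads, L_q, L_k), and then reshapes that to scores' 5-D shape, which is size-preserving. A shape walk-through under assumed sizes (a sketch, not MLX code):

#include <cassert>

int main() {
  // Hypothetical sizes.
  const long B = 2, n_q_heads = 8, n_kv_heads = 2, L_q = 16, L_k = 16;
  const long n_repeats = n_q_heads / n_kv_heads;  // 4

  // After the GQA expand_dims, scores = q @ k^T has shape
  // (B, n_kv_heads, n_repeats, L_q, L_k): 5-D.
  const long scores_elems = B * n_kv_heads * n_repeats * L_q * L_k;

  // A per-example mask (B, 1, L_q, L_k) is broadcast to the query shape with
  // its last axis replaced by L_k: (B, n_q_heads, L_q, L_k).
  const long broadcast_elems = B * n_q_heads * L_q * L_k;

  // Element counts match, so the reshape to scores' 5-D shape is valid, and
  // each batch example's mask lands on its own (n_kv_heads, n_repeats) block
  // of heads instead of leaking across the batch.
  assert(broadcast_elems == scores_elems);
  return 0;
}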