Use expand_dims / unflatten / etc in more places (#1696)

* use expand_dims / unflatten in a couple more places

* few more

* few more

* fix
This commit is contained in:
Awni Hannun
2024-12-12 17:00:44 -08:00
committed by GitHub
parent 9111999af3
commit 50f3535693
3 changed files with 43 additions and 45 deletions

View File

@@ -620,10 +620,15 @@ array scaled_dot_product_attention(
}
auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
if (inputs.size() > 3) {
auto mask_shape = inputs[0].shape();
mask_shape.back() = inputs[1].shape(-2);
auto mask = reshape(
broadcast_to(inputs[3], std::move(mask_shape), s), scores.shape(), s);
// Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
auto mask = inputs[3];
if (n_repeats > 1 && mask.ndim() >= 3) {
if (mask.shape(-3) == 1) {
mask = expand_dims(mask, -3, s);
} else {
mask = unflatten(mask, -3, {n_kv_heads, n_repeats}, s);
}
}
scores = add(scores, mask, s);
}
scores = softmax(scores, std::vector<int>{-1}, true, s);

View File

@@ -542,6 +542,9 @@ array squeeze(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
if (axes.empty()) {
return a;
}
std::set<int> unique_axes;
for (auto ax : axes) {
unique_axes.insert(ax < 0 ? ax + a.ndim() : ax);
@@ -598,6 +601,9 @@ array expand_dims(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
if (axes.empty()) {
return a;
}
{ // Check for repeats
std::set<int> unique_axes(axes.begin(), axes.end());
if (unique_axes.size() != axes.size()) {