Flatten and unflatten (#1692)

* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
This commit is contained in:
Awni Hannun
2024-12-11 21:51:37 -08:00
committed by GitHub
parent 0bf19037ca
commit 4e1e9520e1
19 changed files with 363 additions and 93 deletions

View File

@@ -614,7 +614,7 @@ array scaled_dot_product_attention(
auto k = inputs[1];
auto v = inputs[2];
if (n_repeats > 1) {
q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
q = unflatten(q, 1, {n_kv_heads, n_repeats}, s);
k = expand_dims(k, 2, s);
v = expand_dims(v, 2, s);
}
@@ -629,7 +629,7 @@ array scaled_dot_product_attention(
scores = softmax(scores, std::vector<int>{-1}, true, s);
auto out = matmul(scores, v, s);
if (n_repeats > 1) {
out = reshape(out, {B, n_q_heads, L, -1}, s);
out = flatten(out, 1, 2, s);
}
return std::vector<array>{out};
};