mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Flatten and unflatten (#1692)
* flatten and unflatten
* fix grad
* fix shape infer
* use squeeze + unsqueeze in get_item
This commit is contained in:
@@ -614,7 +614,7 @@ array scaled_dot_product_attention(
|
||||
auto k = inputs[1];
|
||||
auto v = inputs[2];
|
||||
if (n_repeats > 1) {
|
||||
q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
|
||||
q = unflatten(q, 1, {n_kv_heads, n_repeats}, s);
|
||||
k = expand_dims(k, 2, s);
|
||||
v = expand_dims(v, 2, s);
|
||||
}
|
||||
@@ -629,7 +629,7 @@ array scaled_dot_product_attention(
|
||||
scores = softmax(scores, std::vector<int>{-1}, true, s);
|
||||
auto out = matmul(scores, v, s);
|
||||
if (n_repeats > 1) {
|
||||
out = reshape(out, {B, n_q_heads, L, -1}, s);
|
||||
out = flatten(out, 1, 2, s);
|
||||
}
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user