mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:11:43 +08:00
Use expand_dims / unflatten / etc in more places (#1696)
* use expand_dims / unflatten in a couple more places * few more * few more * fix
This commit is contained in:
13
mlx/fast.cpp
13
mlx/fast.cpp
@@ -620,10 +620,15 @@ array scaled_dot_product_attention(
|
||||
}
|
||||
auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
|
||||
if (inputs.size() > 3) {
|
||||
auto mask_shape = inputs[0].shape();
|
||||
mask_shape.back() = inputs[1].shape(-2);
|
||||
auto mask = reshape(
|
||||
broadcast_to(inputs[3], std::move(mask_shape), s), scores.shape(), s);
|
||||
// Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
|
||||
auto mask = inputs[3];
|
||||
if (n_repeats > 1 && mask.ndim() >= 3) {
|
||||
if (mask.shape(-3) == 1) {
|
||||
mask = expand_dims(mask, -3, s);
|
||||
} else {
|
||||
mask = unflatten(mask, -3, {n_kv_heads, n_repeats}, s);
|
||||
}
|
||||
}
|
||||
scores = add(scores, mask, s);
|
||||
}
|
||||
scores = softmax(scores, std::vector<int>{-1}, true, s);
|
||||
|
@@ -542,6 +542,9 @@ array squeeze(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
std::set<int> unique_axes;
|
||||
for (auto ax : axes) {
|
||||
unique_axes.insert(ax < 0 ? ax + a.ndim() : ax);
|
||||
@@ -598,6 +601,9 @@ array expand_dims(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
{ // Check for repeats
|
||||
std::set<int> unique_axes(axes.begin(), axes.end());
|
||||
if (unique_axes.size() != axes.size()) {
|
||||
|
Reference in New Issue
Block a user