mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use expand_dims / unflatten / etc in more places (#1696)
* use expand_dims / unflatten in a couple more places
* few more
* few more
* fix
This commit is contained in:
13
mlx/fast.cpp
13
mlx/fast.cpp
@@ -620,10 +620,15 @@ array scaled_dot_product_attention(
|
||||
}
|
||||
auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
|
||||
if (inputs.size() > 3) {
|
||||
auto mask_shape = inputs[0].shape();
|
||||
mask_shape.back() = inputs[1].shape(-2);
|
||||
auto mask = reshape(
|
||||
broadcast_to(inputs[3], std::move(mask_shape), s), scores.shape(), s);
|
||||
// Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
|
||||
auto mask = inputs[3];
|
||||
if (n_repeats > 1 && mask.ndim() >= 3) {
|
||||
if (mask.shape(-3) == 1) {
|
||||
mask = expand_dims(mask, -3, s);
|
||||
} else {
|
||||
mask = unflatten(mask, -3, {n_kv_heads, n_repeats}, s);
|
||||
}
|
||||
}
|
||||
scores = add(scores, mask, s);
|
||||
}
|
||||
scores = softmax(scores, std::vector<int>{-1}, true, s);
|
||||
|
||||
Reference in New Issue
Block a user