mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
fix per-example mask + docs in sdpa (#1574)
This commit is contained in:
17
mlx/fast.cpp
17
mlx/fast.cpp
@@ -533,6 +533,12 @@ array scaled_dot_product_attention(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
if (mask and (*mask).ndim() > 4) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] the mask with shape "
|
||||
<< (*mask).shape() << " expected to have at most rank 4";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
const size_t batch_dim = queries.shape(0);
|
||||
for (const auto& tensor : {keys, values}) {
|
||||
@@ -599,8 +605,7 @@ array scaled_dot_product_attention(
|
||||
threshold = std::max(1, memory_efficient_threshold.value());
|
||||
}
|
||||
|
||||
bool needs_mask = mask.has_value();
|
||||
auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
|
||||
auto fallback = [scale, final_type, n_q_heads, n_kv_heads, &s](
|
||||
const std::vector<array>& inputs) {
|
||||
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
|
||||
int n_repeats = n_q_heads / n_kv_heads;
|
||||
@@ -614,8 +619,12 @@ array scaled_dot_product_attention(
|
||||
v = expand_dims(v, 2, s);
|
||||
}
|
||||
auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
|
||||
if (needs_mask) {
|
||||
scores = add(scores, inputs[3], s);
|
||||
if (inputs.size() > 3) {
|
||||
auto mask_shape = inputs[0].shape();
|
||||
mask_shape.back() = inputs[1].shape(-2);
|
||||
auto mask = reshape(
|
||||
broadcast_to(inputs[3], std::move(mask_shape), s), scores.shape(), s);
|
||||
scores = add(scores, mask, s);
|
||||
}
|
||||
scores = softmax(scores, std::vector<int>{-1}, true, s);
|
||||
auto out = matmul(scores, v, s);
|
||||
|
Reference in New Issue
Block a user