Mirror of https://github.com/ml-explore/mlx.git
fix per-example mask + docs in sdpa (#1574)
This commit is contained in:
parent 9f0d5c12fc
commit 91c0277356

Changed files: mlx/fast.cpp (17)
mlx/fast.cpp

@@ -533,6 +533,12 @@ array scaled_dot_product_attention(
       throw std::invalid_argument(msg.str());
     }
   }
+  if (mask and (*mask).ndim() > 4) {
+    std::ostringstream msg;
+    msg << "[scaled_dot_product_attention] the mask with shape "
+        << (*mask).shape() << " expected to have at most rank 4";
+    throw std::invalid_argument(msg.str());
+  }
 
   const size_t batch_dim = queries.shape(0);
   for (const auto& tensor : {keys, values}) {
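The new check rejects masks with more than four dimensions. A quick sketch of how that should look from the Python API (shapes are illustrative; this assumes ``std::invalid_argument`` surfaces as a ``ValueError`` through the bindings):

import mlx.core as mx

q = mx.random.normal(shape=(2, 8, 4, 32))
k = mx.random.normal(shape=(2, 2, 8, 32))
v = mx.random.normal(shape=(2, 2, 8, 32))
bad_mask = mx.zeros((1, 2, 1, 4, 8))  # rank 5: one dimension too many

try:
    mx.fast.scaled_dot_product_attention(q, k, v, scale=0.25, mask=bad_mask)
except ValueError as e:
    print(e)  # expected to mention "at most rank 4"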
@@ -599,8 +605,7 @@ array scaled_dot_product_attention(
     threshold = std::max(1, memory_efficient_threshold.value());
   }
 
-  bool needs_mask = mask.has_value();
-  auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
+  auto fallback = [scale, final_type, n_q_heads, n_kv_heads, &s](
                       const std::vector<array>& inputs) {
     auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
     int n_repeats = n_q_heads / n_kv_heads;
@@ -614,8 +619,12 @@ array scaled_dot_product_attention(
       v = expand_dims(v, 2, s);
     }
     auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
-    if (needs_mask) {
-      scores = add(scores, inputs[3], s);
+    if (inputs.size() > 3) {
+      auto mask_shape = inputs[0].shape();
+      mask_shape.back() = inputs[1].shape(-2);
+      auto mask = reshape(
+          broadcast_to(inputs[3], std::move(mask_shape), s), scores.shape(), s);
+      scores = add(scores, mask, s);
     }
     scores = softmax(scores, std::vector<int>{-1}, true, s);
     auto out = matmul(scores, v, s);
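In the fallback, GQA folds the query heads into a ``[B, N_kv, n_repeats, T_q, D]`` layout, so the scores gain an extra axis that a 4-D per-example mask cannot be added to directly; hence the ``broadcast_to`` followed by ``reshape``. A rough Python sketch of the same shape handling (illustrative shapes, not the actual kernel path):

import mlx.core as mx

B, N_q, N_kv, T_q, T_kv, D = 2, 8, 2, 4, 8, 32
n_repeats = N_q // N_kv
scale = D ** -0.5

q = mx.random.normal(shape=(B, N_q, T_q, D))
k = mx.random.normal(shape=(B, N_kv, T_kv, D))
mask = mx.random.normal(shape=(B, 1, T_q, T_kv))  # per-example additive mask

# fold query heads into groups, as the fallback does for GQA
qg = mx.reshape(scale * q, (B, N_kv, n_repeats, T_q, D))
scores = qg @ mx.swapaxes(mx.expand_dims(k, 2), -1, -2)  # [B, N_kv, n_repeats, T_q, T_kv]

# broadcast the mask to the full per-head shape, then match the grouped score layout
mask_full = mx.broadcast_to(mask, (B, N_q, T_q, T_kv))
scores = scores + mx.reshape(mask_full, scores.shape)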
@@ -140,12 +140,23 @@ void init_fast(nb::module_& parent_module) {
         Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
         and ``v`` inputs should not be pre-tiled to match ``q``.
 
+        In the following the dimensions are given by:
+
+        * ``B``: The batch size.
+        * ``N_q``: The number of query heads.
+        * ``N_kv``: The number of key and value heads.
+        * ``T_q``: The number of queries per example.
+        * ``T_kv``: The number of keys and values per example.
+        * ``D``: The per-head dimension.
+
         Args:
-            q (array): Input query array.
-            k (array): Input keys array.
-            v (array): Input values array.
+            q (array): Queries with shape ``[B, N_q, T_q, D]``.
+            k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
+            v (array): Values with shape ``[B, N_kv, T_kv, D]``.
             scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
-            mask (array, optional): An additive mask to apply to the query-key scores.
+            mask (array, optional): An additive mask to apply to the query-key
+              scores. The mask can have at most 4 dimensions and must be
+              broadcast-compatible with the shape ``[B, N, T_q, T_kv]``.
         Returns:
             array: The output array.
         )pbdoc");
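For reference, a minimal usage sketch of the documented behaviour (shapes and values are illustrative only): the mask is additive and only has to broadcast against ``[B, N, T_q, T_kv]``, so a single per-example mask of shape ``[B, 1, T_q, T_kv]`` covers all heads, including the GQA case.

import math
import mlx.core as mx

B, N_q, N_kv, T_q, T_kv, D = 2, 8, 2, 4, 8, 32

q = mx.random.normal(shape=(B, N_q, T_q, D))
k = mx.random.normal(shape=(B, N_kv, T_kv, D))
v = mx.random.normal(shape=(B, N_kv, T_kv, D))
scale = 1.0 / math.sqrt(D)

# additive mask, broadcast over heads: 0 keeps a position, a large negative value suppresses it
mask = mx.zeros((B, 1, T_q, T_kv))

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
print(out.shape)  # [B, N_q, T_q, D]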
@@ -7,14 +7,16 @@ import numpy as np
 
 
 # SDPA for MHA (n_heads == n_kv_heads)
-def mlx_primitives_sdpa(q, k, v, scale):
+def mlx_primitives_sdpa(q, k, v, scale, mask=None):
     p = (q * scale) @ k.transpose(0, 1, 3, 2)
+    if mask is not None:
+        p += mask
     scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype)
     return scores @ v
 
 
 # SDPA for GQA (n_heads > n_kv_heads, n_kv_heads > 1, n_heads % n_kv_heads == 0)
-def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
+def mlx_primitives_sdpa_with_gqa(q, k, v, scale, mask=None):
     n_repeats = q.shape[1] // k.shape[1]
 
     # borrowing kv cache tiling from mlx-examples/llms/mistral/mistral.py
@@ -28,7 +30,7 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
 
     k, v = map(repeat, (k, v))
 
-    return mlx_primitives_sdpa(q, k, v, scale)
+    return mlx_primitives_sdpa(q, k, v, scale, mask=mask)
 
 
 class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
@@ -176,6 +178,15 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         y_hat = mx.fast.scaled_dot_product_attention(q, k, v[:, :, :32], scale=scale)
         self.assertTrue(mx.allclose(y, y_hat, atol=atol))
 
+        # Test with per-example mask
+        q = mx.random.normal(shape=(2, 8, 4, 32))
+        k = mx.random.normal(shape=(2, 2, 8, 32))
+        v = mx.random.normal(shape=(2, 2, 8, 32))
+        mask = 10 * mx.random.normal(shape=(2, 1, 4, 8))
+        y = mlx_primitives_sdpa_with_gqa(q, k, v, scale, mask=mask)
+        y_hat = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
+        self.assertTrue(mx.allclose(y, y_hat, atol=atol))
+
 
 if __name__ == "__main__":
     unittest.main(failfast=True)