mirror of https://github.com/ml-explore/mlx.git
Add sdpa with sinks (#2558)
* add sdpa with sinks
* fix 2 pass
* fix matrix sdpa
* fix perf regression
* add to cuda (#2580)
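For orientation, here is a minimal sketch of the new `sinks` argument from the Python side. Shapes follow the docstring in the diff below; the values and sizes are illustrative only:

```python
import mlx.core as mx

# Illustrative sizes: batch 1, 8 heads, 128 positions, head dim 64.
q = mx.random.normal(shape=(1, 8, 128, 64))
k = mx.random.normal(shape=(1, 8, 128, 64))
v = mx.random.normal(shape=(1, 8, 128, 64))

# One sink logit per query head, as exercised by the tests below.
sinks = mx.random.normal(shape=(8,))

out = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=64**-0.5, mask="causal", sinks=sinks
)
print(out.shape)  # (1, 8, 128, 64)
```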
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -196,6 +196,7 @@ void init_fast(nb::module_& parent_module) {
          const mx::array& values,
          const float scale,
          const std::variant<std::monostate, std::string, mx::array>& mask,
+         const std::optional<mx::array>& sinks,
          mx::StreamOrDevice s) {
         bool has_mask = !std::holds_alternative<std::monostate>(mask);
         bool has_str_mask =
@@ -212,16 +213,16 @@ void init_fast(nb::module_& parent_module) {
             throw std::invalid_argument(msg.str());
           }
           return mx::fast::scaled_dot_product_attention(
-              queries, keys, values, scale, mask_str, {}, s);
+              queries, keys, values, scale, mask_str, {}, sinks, s);
         } else {
           auto mask_arr = std::get<mx::array>(mask);
           return mx::fast::scaled_dot_product_attention(
-              queries, keys, values, scale, "", {mask_arr}, s);
+              queries, keys, values, scale, "", {mask_arr}, sinks, s);
         }
       } else {
         return mx::fast::scaled_dot_product_attention(
-            queries, keys, values, scale, "", {}, s);
+            queries, keys, values, scale, "", {}, sinks, s);
       }
     },
     "q"_a,
@@ -230,9 +231,10 @@ void init_fast(nb::module_& parent_module) {
     nb::kw_only(),
     "scale"_a,
     "mask"_a = nb::none(),
+    "sinks"_a = nb::none(),
     "stream"_a = nb::none(),
     nb::sig(
-        "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
+        "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, sinks: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
     R"pbdoc(
       A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
 
@@ -262,14 +264,17 @@ void init_fast(nb::module_& parent_module) {
         q (array): Queries with shape ``[B, N_q, T_q, D]``.
         k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
         v (array): Values with shape ``[B, N_kv, T_kv, D]``.
-        scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
-        mask (Union[None, str, array], optional): The mask to apply to the
+        scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``).
+        mask (str or array, optional): The mask to apply to the
           query-key scores. The mask can be an array or a string indicating
           the mask type. The only supported string type is ``"causal"``. If
           the mask is an array it can be a boolean or additive mask. The mask
           can have at most 4 dimensions and must be broadcast-compatible with
           the shape ``[B, N, T_q, T_kv]``. If an additive mask is given its
           type must promote to the promoted type of ``q``, ``k``, and ``v``.
+        sinks (array, optional): An optional array of attention sinks.
+          Default: ``None``.
 
     Returns:
       array: The output array.
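One behavioral detail worth noting from the tests below: `sinks` must promote to the promoted type of the inputs, so passing default float32 sinks with half-precision `q`, `k`, `v` raises. A small sketch (shapes are illustrative):

```python
import mlx.core as mx

q = mx.random.normal(shape=(1, 8, 32, 64)).astype(mx.float16)
k = mx.random.normal(shape=(1, 8, 32, 64)).astype(mx.float16)
v = mx.random.normal(shape=(1, 8, 32, 64)).astype(mx.float16)
sinks = mx.random.normal(shape=(8,))  # float32: won't promote to the float16 inputs

try:
    mx.fast.scaled_dot_product_attention(q, k, v, scale=64**-0.5, sinks=sinks)
except ValueError as err:
    print("rejected:", err)
```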
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -6,7 +6,7 @@ import mlx_tests
 import numpy as np
 
 
-def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
+def mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None):
     q_dtype = q.dtype
     q = q * mx.array(scale, q_dtype)
     n_q_heads = q.shape[-3]
@@ -23,7 +23,6 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
     v = mx.expand_dims(v, 2)
 
     scores = q @ mx.swapaxes(k, -1, -2)
 
     if mask is not None:
-
         if mask == "causal":
@@ -43,7 +42,18 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
         else:
             scores += mask
 
+    if sinks is not None:
+        sinks = mx.expand_dims(sinks, (0, 2, 3))
+        if n_repeats > 1:
+            sinks = mx.unflatten(sinks, 1, (n_kv_heads, n_repeats))
+        score_shape = list(scores.shape)
+        score_shape[-1] = 1
+        sinks = mx.broadcast_to(sinks, score_shape)
+        scores = mx.concatenate([sinks, scores], axis=-1)
+
     scores = mx.softmax(scores, axis=-1, precise=True)
+    if sinks is not None:
+        scores = scores[..., 1:]
 
     out = scores @ v
     if n_repeats > 1:
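The reference change above realizes sinks by prepending one extra logit column before the softmax and dropping it afterwards. In math (my notation, not from the patch), with per-head sink logit $t_h$ and scores $s_{ij} = \text{scale} \cdot q_i \cdot k_j$, the attention weights become

$$
w_{ij} = \frac{\exp(s_{ij})}{\exp(t_h) + \sum_{j'} \exp(s_{ij'})},
$$

so each row sums to less than one, with the residual mass absorbed by the sink rather than forced onto the keys.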
@@ -158,7 +168,7 @@ class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
 
         Dk = 64
 
-        if self.is_apple_silicon:
+        if self.is_apple_silicon or mx.cuda.is_available():
             dtypes.append(np.half)
 
         for SEQUENCE_LENGTH in [63, 129, 400]:
@@ -230,7 +240,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         B = 1
         H = 32
         dtypes = [np.float32]
-        if self.is_apple_silicon:
+        if self.is_apple_silicon or mx.cuda.is_available():
            dtypes.append(np.half)
 
         for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
@@ -400,15 +410,30 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
 
     def test_fully_masked(self):
         Lkv = 8
-        mask = mx.array(False)
-        for D in [4, 128]:
-            for Lq in [1, 8]:
-                q = mx.random.normal(shape=(1, 4, Lq, D))
-                k = mx.random.normal(shape=(1, 4, Lkv, D))
-                v = mx.random.normal(shape=(1, 4, Lkv, D))
-
-                out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1)
-                self.assertTrue(mx.all(mx.isnan(out)))
+        masks = [mx.array(False), mx.array(-float("inf"))]
+        for mask in masks:
+            for D in [4, 128]:
+                for Lq in [1, 8]:
+                    q = mx.random.normal(shape=(1, 4, Lq, D))
+                    k = mx.random.normal(shape=(1, 4, Lkv, D))
+                    v = mx.random.normal(shape=(1, 4, Lkv, D))
+
+                    out = mx.fast.scaled_dot_product_attention(
+                        q, k, v, mask=mask, scale=1
+                    )
+                    self.assertTrue(mx.all(mx.isnan(out)))
+
+    def test_inf_score(self):
+        Lkv = 8
+        for D in [4, 128]:
+            for Lq in [1, 8]:
+                q = mx.ones(shape=(1, 4, Lq, D))
+                k = mx.ones(shape=(1, 4, Lkv, D))
+                v = mx.random.normal(shape=(1, 4, Lkv, D))
+                k[..., 0, :] = -float("inf")
+                ref = mlx_primitives_sdpa(q, k, v, scale=1, mask=None)
+                out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1)
+                self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
 
     def test_fast_sdpa_few_query(self):
         D = 64
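The fully-masked expectation above follows from plain softmax semantics: when every score is $-\infty$, the exponentials sum to zero and normalization divides 0 by 0. A quick illustration via `mx.softmax`, which the reference implementation also uses:

```python
import mlx.core as mx

row = mx.full((4,), -float("inf"))
print(mx.softmax(row, axis=-1))  # all nan: the normalizer is zero
```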
@@ -674,6 +699,51 @@ class TestSDPA(mlx_tests.MLXTestCase):
         self.assertFalse(mx.isnan(out).any().item())
         self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)
 
+    def test_sdpa_attention_sinks(self):
+        B = 2
+        N_q = N_kv = 8
+        T_q = T_kv = 128
+        D = 64
+
+        q = mx.random.normal(shape=(B, N_q, T_q, D))
+        k = mx.random.normal(shape=(B, N_kv, T_kv, D))
+        v = mx.random.normal(shape=(B, N_kv, T_kv, D))
+        scale = D**-0.5
+
+        # sinks should promote to correct type
+        sinks = mx.random.normal(shape=(N_q,))
+        with self.assertRaises(ValueError):
+            mx.fast.scaled_dot_product_attention(
+                q.astype(mx.float16),
+                k.astype(mx.float16),
+                v.astype(mx.float16),
+                scale=scale,
+                sinks=sinks,
+            )
+
+        # Wrong shapes
+        sinks = mx.random.normal(shape=(N_q + 1,))
+        with self.assertRaises(ValueError):
+            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
+
+        sinks = mx.random.normal(shape=())
+        with self.assertRaises(ValueError):
+            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
+
+        for T_kv in [128, 4096]:
+            for T_q in [1, 128]:
+                for N_kv in [2, 8]:
+                    q = mx.random.normal(shape=(B, N_q, T_q, D))
+                    k = mx.random.normal(shape=(B, N_kv, T_kv, D))
+                    v = mx.random.normal(shape=(B, N_kv, T_kv, D))
+                    sinks = 10 * mx.random.normal(shape=(N_q,))
+
+                    expected = mlx_ref_attn(q, k, v, scale, sinks=sinks)
+                    out = mx.fast.scaled_dot_product_attention(
+                        q, k, v, scale=scale, sinks=sinks
+                    )
+                    self.assertTrue(mx.allclose(out, expected, atol=1e-5))
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner(failfast=True)