mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
causal vector sdpa (#2018)
* causal vector sdpa * get rid of memory threshold
This commit is contained in:
@@ -131,7 +131,6 @@ void init_fast(nb::module_& parent_module) {
|
||||
nb::kw_only(),
|
||||
"scale"_a,
|
||||
"mask"_a = nb::none(),
|
||||
"memory_efficient_threshold"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
@@ -164,10 +163,10 @@ void init_fast(nb::module_& parent_module) {
|
||||
k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
|
||||
v (array): Values with shape ``[B, N_kv, T_kv, D]``.
|
||||
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
|
||||
mask (Union[None, str, array], optional): A causal, boolean or additive
|
||||
mask to apply to the query-key scores. The mask can have at most 4
|
||||
dimensions and must be broadcast-compatible with the shape
|
||||
``[B, N, T_q, T_kv]``. If an additive mask is given its type must
|
||||
mask (Union[None, str, array], optional): A causal, boolean or additive
|
||||
mask to apply to the query-key scores. The mask can have at most 4
|
||||
dimensions and must be broadcast-compatible with the shape
|
||||
``[B, N, T_q, T_kv]``. If an additive mask is given its type must
|
||||
promote to the promoted type of ``q``, ``k``, and ``v``.
|
||||
Returns:
|
||||
array: The output array.
|
||||
|
@@ -95,7 +95,13 @@ def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
|
||||
def mlx_primitives_sdpa(q, k, v, scale, mask=None):
|
||||
p = (q * scale) @ k.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
if mask.dtype == mx.bool_:
|
||||
if mask == "causal":
|
||||
q_offset = max(0, k.shape[2] - q.shape[2])
|
||||
q_indices = mx.arange(q_offset, q_offset + q.shape[2])
|
||||
k_indices = mx.arange(k.shape[2])
|
||||
mask = q_indices[:, None] >= k_indices[None]
|
||||
p = mx.where(mask, p, mx.finfo(mx.float32).min)
|
||||
elif mask.dtype == mx.bool_:
|
||||
p = mx.where(mask, p, mx.finfo(mx.float32).min)
|
||||
else:
|
||||
p += mask
|
||||
@@ -176,7 +182,10 @@ class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
|
||||
|
||||
reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
|
||||
o_mlx = mx.fast.scaled_dot_product_attention(
|
||||
q_mlx, k_mlx, v_mlx, scale=scale, memory_efficient_threshold=2
|
||||
q_mlx,
|
||||
k_mlx,
|
||||
v_mlx,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
self.assertListEqual(list(reference.shape), list(o_mlx.shape))
|
||||
@@ -342,6 +351,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
mx.array([True] * (L - 10) + [False] * 10),
|
||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
|
||||
"causal",
|
||||
]
|
||||
for m in masks:
|
||||
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
|
||||
@@ -366,6 +376,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
mx.array([True] * (L - 10) + [False] * 10),
|
||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
|
||||
"causal",
|
||||
]
|
||||
for m in masks:
|
||||
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
|
||||
@@ -396,6 +407,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
mx.array([True] * (L - 10) + [False] * 10),
|
||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
|
||||
"causal",
|
||||
]
|
||||
for m in masks:
|
||||
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
|
||||
@@ -420,6 +432,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
mx.array([True] * (L - 10) + [False] * 10),
|
||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
|
||||
"causal",
|
||||
]
|
||||
for m in masks:
|
||||
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
|
||||
|
Reference in New Issue
Block a user