Mirror of https://github.com/ml-explore/mlx-examples.git

Commit c5ce9a31f2 (parent 845efddc8c): length masking
@@ -23,7 +23,12 @@ class BaseModelArgs:
         )
 
 
-def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
+def create_causal_mask(
+    N: int,
+    offset: int = 0,
+    window_size: Optional[int] = None,
+    lengths: Optional[mx.array] = None,
+):
     rinds = mx.arange(offset + N)
     linds = mx.arange(offset, offset + N) if offset else rinds
     linds = linds[:, None]
@@ -31,12 +36,18 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
     mask = linds < rinds
     if window_size is not None:
         mask = mask | (linds > rinds + window_size)
+    if lengths is not None:
+        mask = mx.repeat(mask[None], lengths.shape[0], axis=0)
+        lengths = lengths[:, None, None]
+        mask = mask | (rinds[None] >= lengths)
+        mask = mask[:, None]
     return mask * -1e9
 
 
 def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
     T = h.shape[1]
     if T > 1:
+        lengths = None
         window_size = None
         offset = 0
         if cache is not None and cache[0] is not None:
@@ -46,7 +57,8 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
                 window_size = c.max_size
             else:
                 offset = c.offset
-        mask = create_causal_mask(T, offset, window_size=window_size)
+            lengths = getattr(c, "lengths", None)
+        mask = create_causal_mask(T, offset, window_size=window_size, lengths=lengths)
         mask = mask.astype(h.dtype)
     else:
         mask = None
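
For context (not part of the diff), here is a minimal standalone sketch of what the new `lengths` branch computes, using an illustrative batch of three sequences padded to T = 4 with offset = 0:

```python
import mlx.core as mx

# Illustrative values: three sequences padded to T = 4, true lengths [2, 4, 3].
T = 4
lengths = mx.array([2, 4, 3])

rinds = mx.arange(T)    # key (column) positions
linds = rinds[:, None]  # query (row) positions

mask = linds < rinds                                    # causal: hide future keys
mask = mx.repeat(mask[None], lengths.shape[0], axis=0)  # one mask per batch element
mask = mask | (rinds[None] >= lengths[:, None, None])   # hide padded key positions
mask = mask[:, None]                                    # heads axis -> (B, 1, T, T)

additive = mask * -1e9  # large negative bias zeroes masked weights after softmax
print(additive.shape)   # (3, 1, 4, 4)
```

Row i of the first batch element then attends only to key positions up to min(i, 1), since everything past its length of 2 is padding.
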
@@ -10,6 +10,7 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
 def make_prompt_cache(
     model: nn.Module,
     max_kv_size: Optional[int] = None,
+    lengths: Optional[mx.array] = None,
 ) -> List[Any]:
     """
     Construct the model's cache for use when generating.
@@ -22,17 +23,22 @@ def make_prompt_cache(
         max_kv_size (Optional[int]): If provided and the model does not have a
             ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
             size of ``max_kv_size``
+        lengths (Optional[array]): If provided, these sequence lengths will be
+            used to mask the KV cache. Useful for batched inputs.
     """
     if hasattr(model, "make_cache"):
         return model.make_cache()
 
     num_layers = len(model.layers)
     if max_kv_size is not None:
-        return [
+        cache = [
             RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
         ]
     else:
-        return [KVCache() for _ in range(num_layers)]
+        cache = [KVCache() for _ in range(num_layers)]
+
+    cache[0].lengths = lengths
+    return cache
 
 
 def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
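
A hedged usage sketch of the new `lengths` argument (`TinyModel` is a hypothetical stand-in, not from this commit): `make_prompt_cache` only needs a `.layers` attribute and no custom `make_cache` method to build one `KVCache` per layer, and the lengths are stored on the first cache entry:

```python
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache, make_prompt_cache

# Hypothetical stand-in for a transformer; only `.layers` matters here.
class TinyModel(nn.Module):
    def __init__(self, num_layers: int = 2):
        super().__init__()
        self.layers = [nn.Linear(8, 8) for _ in range(num_layers)]

lengths = mx.array([7, 12, 9])  # true token counts for a batch padded to 12
cache = make_prompt_cache(TinyModel(), lengths=lengths)

assert isinstance(cache[0], KVCache)
assert cache[0].lengths is lengths  # read back later by create_attention_mask
```

Storing the lengths only on `cache[0]` works because `create_attention_mask` builds the mask from the first cache entry and the same mask is shared across layers.
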
@@ -135,6 +141,7 @@ class QuantizedKVCache(_BaseCache):
         self.values = None
         self.offset = 0
         self.step = 256
+        self.lengths = None
         self.group_size = group_size
         self.bits = bits
 
@@ -217,6 +224,7 @@ class KVCache(_BaseCache):
         self.values = None
         self.offset = 0
         self.step = 256
+        self.lengths = None
 
     def update_and_fetch(self, keys, values):
         prev = self.offset
@@ -285,6 +293,7 @@ class RotatingKVCache(_BaseCache):
         self.offset = 0
         self.max_size = max_size
         self.step = step
+        self.lengths = None
         self._idx = 0
 
     def _trim(self, trim_size, v, append=None):
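
All three cache classes now initialize `lengths` to `None`, so single-sequence behavior is unchanged; setting the attribute opts a cache into per-sequence masking. A small sketch of the round trip through `create_attention_mask` (shapes are illustrative):

```python
import mlx.core as mx
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.cache import KVCache

cache = [KVCache()]                     # lengths defaults to None
cache[0].lengths = mx.array([5, 3])     # batch of 2, true lengths 5 and 3

h = mx.zeros((2, 6, 16))                # hidden states padded to 6 tokens
mask = create_attention_mask(h, cache)  # picks up lengths via getattr on cache[0]
print(mask.shape)                       # (2, 1, 6, 6)
```
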
@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_map
 from mlx_lm.models import rope_utils
+from mlx_lm.models.base import create_attention_mask
 from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
 
 
@@ -128,6 +129,25 @@ class TestModels(unittest.TestCase):
         self.assertEqual(cache.offset, 22)
         self.assertTrue(mx.allclose(x, k[..., -2:, :]))
 
+    def test_cache_lengths(self):
+        mx.random.seed(8)
+        B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
+        lengths = mx.array([1, 2, 3, 1])
+        h = mx.random.uniform(shape=(B, T_q, D))
+        q = mx.random.uniform(shape=(B, N_q, T_q, D))
+        k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
+        v = k
+        cache = [KVCache()]
+        cache[0].lengths = lengths
+        mask = create_attention_mask(h, cache)
+
+        out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
+        q[1, :, 2:] = mx.ones_like(q[1, :, 2:])
+        k[1, :, 2:] = mx.ones_like(k[1, :, 2:])
+        v[1, :, 2:] = mx.ones_like(v[1, :, 2:])
+        out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
+        self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2]))
+
     def test_rope(self):
         rope = rope_utils.initialize_rope(32, base=100, traditional=False)
         self.assertTrue(isinstance(rope, nn.RoPE))
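
Why the new test passes, in brief: with lengths `[1, 2, 3, 1]` and `T_kv = 3`, key position 2 is padding for sequence 1, so its mask column is `-1e9` and softmax gives it essentially zero weight; perturbing `q`, `k`, and `v` there cannot change the outputs at the first two positions. A quick way to inspect that column:

```python
import mlx.core as mx
from mlx_lm.models.base import create_causal_mask

mask = create_causal_mask(3, lengths=mx.array([1, 2, 3, 1]))
print(mask.shape)   # (4, 1, 3, 3)
print(mask[1, 0])   # sequence 1: column 2 is -1e9 in every row (padding)
```
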