length masking

This commit is contained in:
Alex Barron 2024-12-17 22:35:45 -08:00
parent 845efddc8c
commit c5ce9a31f2
3 changed files with 45 additions and 4 deletions

View File

@ -23,7 +23,12 @@ class BaseModelArgs:
)
def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
def create_causal_mask(
N: int,
offset: int = 0,
window_size: Optional[int] = None,
lengths: Optional[mx.array] = None,
):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
linds = linds[:, None]
@ -31,12 +36,18 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
mask = linds < rinds
if window_size is not None:
mask = mask | (linds > rinds + window_size)
if lengths is not None:
mask = mx.repeat(mask[None], lengths.shape[0], axis=0)
lengths = lengths[:, None, None]
mask = mask | (rinds[None] >= lengths)
mask = mask[:, None]
return mask * -1e9
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
T = h.shape[1]
if T > 1:
lengths = None
window_size = None
offset = 0
if cache is not None and cache[0] is not None:
@ -46,7 +57,8 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
window_size = c.max_size
else:
offset = c.offset
mask = create_causal_mask(T, offset, window_size=window_size)
lengths = getattr(c, "lengths", None)
mask = create_causal_mask(T, offset, window_size=window_size, lengths=lengths)
mask = mask.astype(h.dtype)
else:
mask = None

View File

@ -10,6 +10,7 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
lengths: Optional[mx.array] = None,
) -> List[Any]:
"""
Construct the model's cache for use in generation.
@ -22,17 +23,22 @@ def make_prompt_cache(
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``
lengths (Optional[array]): If provided, these sequence lengths will be
used to mask the KV cache. Useful for batched inputs.
"""
if hasattr(model, "make_cache"):
return model.make_cache()
num_layers = len(model.layers)
if max_kv_size is not None:
return [
cache = [
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
]
else:
return [KVCache() for _ in range(num_layers)]
cache = [KVCache() for _ in range(num_layers)]
cache[0].lengths = lengths
return cache
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
@ -135,6 +141,7 @@ class QuantizedKVCache(_BaseCache):
self.values = None
self.offset = 0
self.step = 256
self.lengths = None
self.group_size = group_size
self.bits = bits
@ -217,6 +224,7 @@ class KVCache(_BaseCache):
self.values = None
self.offset = 0
self.step = 256
self.lengths = None
def update_and_fetch(self, keys, values):
prev = self.offset
@ -285,6 +293,7 @@ class RotatingKVCache(_BaseCache):
self.offset = 0
self.max_size = max_size
self.step = step
self.lengths = None
self._idx = 0
def _trim(self, trim_size, v, append=None):

View File

@ -5,6 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map
from mlx_lm.models import rope_utils
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
@ -128,6 +129,25 @@ class TestModels(unittest.TestCase):
self.assertEqual(cache.offset, 22)
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
def test_cache_lengths(self):
mx.random.seed(8)
B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
lengths = mx.array([1, 2, 3, 1])
h = mx.random.uniform(shape=(B, T_q, D))
q = mx.random.uniform(shape=(B, N_q, T_q, D))
k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
v = k
cache = [KVCache()]
cache[0].lengths = lengths
mask = create_attention_mask(h, cache)
out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
q[1, :, 2:] = mx.ones_like(q[1, :, 2:])
k[1, :, 2:] = mx.ones_like(k[1, :, 2:])
v[1, :, 2:] = mx.ones_like(v[1, :, 2:])
out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2]))
def test_rope(self):
rope = rope_utils.initialize_rope(32, base=100, traditional=False)
self.assertTrue(isinstance(rope, nn.RoPE))