diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index f02f49b1..538bc51c 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -23,7 +23,12 @@ class BaseModelArgs:
         )
 
 
-def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
+def create_causal_mask(
+    N: int,
+    offset: int = 0,
+    window_size: Optional[int] = None,
+    lengths: Optional[mx.array] = None,
+):
     rinds = mx.arange(offset + N)
     linds = mx.arange(offset, offset + N) if offset else rinds
     linds = linds[:, None]
@@ -31,12 +36,18 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
     mask = linds < rinds
     if window_size is not None:
         mask = mask | (linds > rinds + window_size)
+    if lengths is not None:
+        mask = mx.repeat(mask[None], lengths.shape[0], axis=0)
+        lengths = lengths[:, None, None]
+        mask = mask | (rinds[None] >= lengths)
+        mask = mask[:, None]
     return mask * -1e9
 
 
 def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
     T = h.shape[1]
     if T > 1:
+        lengths = None
         window_size = None
         offset = 0
         if cache is not None and cache[0] is not None:
@@ -46,7 +57,8 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
                 window_size = c.max_size
             else:
                 offset = c.offset
-        mask = create_causal_mask(T, offset, window_size=window_size)
+            lengths = getattr(c, "lengths", None)
+        mask = create_causal_mask(T, offset, window_size=window_size, lengths=lengths)
         mask = mask.astype(h.dtype)
     else:
         mask = None
diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index 14026f0c..1e311381 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -10,6 +10,7 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
 def make_prompt_cache(
     model: nn.Module,
     max_kv_size: Optional[int] = None,
+    lengths: Optional[mx.array] = None,
 ) -> List[Any]:
     """
     Construct the model's cache for use when generating.
@@ -22,17 +23,22 @@ def make_prompt_cache(
         max_kv_size (Optional[int]): If provided and the model does not have a
             ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
             size of ``max_kv_size``
+        lengths (Optional[array]): If provided, these sequence lengths will be
+            used to mask the KV cache. Useful for batch inputs.
""" if hasattr(model, "make_cache"): return model.make_cache() num_layers = len(model.layers) if max_kv_size is not None: - return [ + cache = [ RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers) ] else: - return [KVCache() for _ in range(num_layers)] + cache = [KVCache() for _ in range(num_layers)] + + cache[0].lengths = lengths + return cache def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}): @@ -135,6 +141,7 @@ class QuantizedKVCache(_BaseCache): self.values = None self.offset = 0 self.step = 256 + self.lengths = None self.group_size = group_size self.bits = bits @@ -217,6 +224,7 @@ class KVCache(_BaseCache): self.values = None self.offset = 0 self.step = 256 + self.lengths = None def update_and_fetch(self, keys, values): prev = self.offset @@ -285,6 +293,7 @@ class RotatingKVCache(_BaseCache): self.offset = 0 self.max_size = max_size self.step = step + self.lengths = None self._idx = 0 def _trim(self, trim_size, v, append=None): diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 3097c522..6ae7d803 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -5,6 +5,7 @@ import mlx.core as mx import mlx.nn as nn from mlx.utils import tree_map from mlx_lm.models import rope_utils +from mlx_lm.models.base import create_attention_mask from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache @@ -128,6 +129,25 @@ class TestModels(unittest.TestCase): self.assertEqual(cache.offset, 22) self.assertTrue(mx.allclose(x, k[..., -2:, :])) + def test_cache_lengths(self): + mx.random.seed(8) + B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2) + lengths = mx.array([1, 2, 3, 1]) + h = mx.random.uniform(shape=(B, T_q, D)) + q = mx.random.uniform(shape=(B, N_q, T_q, D)) + k = mx.random.uniform(shape=(B, N_kv, T_kv, D)) + v = k + cache = [KVCache()] + cache[0].lengths = lengths + mask = create_attention_mask(h, cache) + + out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) + q[1, :, 2:] = mx.ones_like(q[1, :, 2:]) + k[1, :, 2:] = mx.ones_like(k[1, :, 2:]) + v[1, :, 2:] = mx.ones_like(v[1, :, 2:]) + out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) + self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2])) + def test_rope(self): rope = rope_utils.initialize_rope(32, base=100, traditional=False) self.assertTrue(isinstance(rope, nn.RoPE))