diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index cda41c79..f02f49b1 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -42,7 +42,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
         if cache is not None and cache[0] is not None:
             c = cache[0]
             if hasattr(c, "max_size"):
-                offset = min(c.max_size - 1, c.offset)
+                offset = min(c.max_size, c.offset)
                 window_size = c.max_size
             else:
                 offset = c.offset
diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index 1cd5289d..14026f0c 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -325,9 +325,9 @@ class RotatingKVCache(_BaseCache):
             self.keys = self._temporal_order(self.keys)
             self.values = self._temporal_order(self.values)
 
-            # The largest size is self.max_size + S - 1 to ensure
+            # The largest size is self.max_size + S to ensure
             # every token gets at least self.max_size context
-            trim_size = self._idx - self.max_size + 1
+            trim_size = self._idx - self.max_size
            self.keys = self._trim(trim_size, self.keys, keys)
            self.values = self._trim(trim_size, self.values, values)
        self.offset += keys.shape[2]
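
For context, a minimal standalone sketch (not part of the patch) of the off-by-one arithmetic this diff fixes. The toy values are assumptions chosen for illustration: a rotating cache with max_size = 4 that already holds 10 tokens, and a new chunk of S = 3 tokens; the names cached and S are shorthand, not identifiers from the repo.

# base.py: mask offset visible to the new tokens
max_size, cached, S = 4, 10, 3
old_offset = min(max_size - 1, cached)  # 3 -> one token of context is lost
new_offset = min(max_size, cached)      # 4 -> a full max_size window is visible

# cache.py: how much history is trimmed before concatenating the new chunk
old_trim = cached - max_size + 1  # 7 -> keeps only max_size - 1 old tokens
new_trim = cached - max_size      # 6 -> keeps max_size old tokens, so the
                                  #      concatenated size is max_size + S

print(old_offset, new_offset)  # 3 4
print(old_trim, new_trim)      # 7 6

With the trim reduced by one, every newly appended token sees at least max_size tokens of context, matching the updated comment in RotatingKVCache, and the mask offset in create_attention_mask is clamped to the same max_size window.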