From ed9e81dd581a9505e677e12c025137d5326fe6df Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 5 Nov 2024 10:24:24 -0800
Subject: [PATCH] Fix rotating kv cache size (#1093)

---
 llms/mlx_lm/models/base.py  | 2 +-
 llms/mlx_lm/models/cache.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index cda41c79..f02f49b1 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -42,7 +42,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
         if cache is not None and cache[0] is not None:
             c = cache[0]
             if hasattr(c, "max_size"):
-                offset = min(c.max_size - 1, c.offset)
+                offset = min(c.max_size, c.offset)
                 window_size = c.max_size
             else:
                 offset = c.offset
diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index 1cd5289d..14026f0c 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -325,9 +325,9 @@ class RotatingKVCache(_BaseCache):
             self.keys = self._temporal_order(self.keys)
             self.values = self._temporal_order(self.values)
 
-            # The largest size is self.max_size + S - 1 to ensure
+            # The largest size is self.max_size + S to ensure
             # every token gets at least self.max_size context
-            trim_size = self._idx - self.max_size + 1
+            trim_size = self._idx - self.max_size
            self.keys = self._trim(trim_size, self.keys, keys)
            self.values = self._trim(trim_size, self.values, values)
        self.offset += keys.shape[2]
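
Note on the arithmetic (not part of the patch): self._idx counts the tokens
currently stored, so trimming self._idx - self.max_size of them leaves exactly
self.max_size old tokens; after concatenating S new keys the cache holds at
most self.max_size + S entries, and every new token still attends to a full
self.max_size of history. The previous "+ 1" trimmed one token too many,
leaving only self.max_size - 1 tokens of context. A minimal standalone sketch
of that bookkeeping follows; the helper name trimmed_length and the example
numbers are illustrative assumptions, not code from the repo.

    # Standalone sketch mirroring the trim logic of a rotating KV cache
    # in its concatenation path (hypothetical helper, not from mlx_lm).

    def trimmed_length(idx: int, max_size: int, new_tokens: int) -> int:
        """Cache length after trimming old entries and appending new ones."""
        trim_size = max(idx - max_size, 0)  # fixed formula: no trailing "+ 1"
        return (idx - trim_size) + new_tokens


    if __name__ == "__main__":
        max_size = 8
        # 20 tokens already stored in temporal order; append 3 more.
        print(trimmed_length(20, max_size, 3))  # 11 == max_size + 3
        # With the old "+ 1", only max_size - 1 = 7 old tokens would be
        # kept, so the first appended token would see too little context.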