Fix rotating kv cache size (#1093)

Angelos Katharopoulos 2024-11-05 10:24:24 -08:00 committed by GitHub
parent 6fd1f70f73
commit ed9e81dd58
2 changed files with 3 additions and 3 deletions


@@ -42,7 +42,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
         if cache is not None and cache[0] is not None:
             c = cache[0]
             if hasattr(c, "max_size"):
-                offset = min(c.max_size - 1, c.offset)
+                offset = min(c.max_size, c.offset)
                 window_size = c.max_size
             else:
                 offset = c.offset
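
For context, the hunk above clamps the attention-mask offset to the rotating cache's window: once more than max_size tokens are cached, new queries should still attend to a full window of max_size past positions, not max_size - 1. A minimal standalone sketch of that clamp arithmetic (the helper name mask_offset and the example numbers are hypothetical, not part of mlx_lm):

# Illustrative only; not the mlx_lm implementation.
def mask_offset(max_size: int, offset: int) -> int:
    # The offset is how many past positions the new queries may attend to.
    # With a rotating cache it can never exceed the window size max_size.
    # The previous clamp, min(max_size - 1, offset), lost one position
    # once the cache had filled up.
    return min(max_size, offset)

# Example: a window of 8 positions with 20 tokens already cached; the new
# queries should see a full window of 8 past positions.
assert mask_offset(8, 20) == 8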


@@ -325,9 +325,9 @@ class RotatingKVCache(_BaseCache):
             self.keys = self._temporal_order(self.keys)
             self.values = self._temporal_order(self.values)
 
-            # The largest size is self.max_size + S - 1 to ensure
+            # The largest size is self.max_size + S to ensure
             # every token gets at least self.max_size context
-            trim_size = self._idx - self.max_size + 1
+            trim_size = self._idx - self.max_size
             self.keys = self._trim(trim_size, self.keys, keys)
             self.values = self._trim(trim_size, self.values, values)
         self.offset += keys.shape[2]
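
Similarly, the corrected trim size keeps exactly max_size old entries before the S new keys are concatenated, so the cache peaks at max_size + S entries and every new token sees at least max_size of context. A standalone sketch of that bookkeeping (the helper trimmed_length and the example numbers are hypothetical, not part of mlx_lm):

# Illustrative only; not the mlx_lm implementation.
def trimmed_length(idx: int, max_size: int, new_tokens: int) -> int:
    # idx is the number of entries currently held in the cache.
    # Dropping idx - max_size entries (nothing if that is negative)
    # keeps max_size old entries plus the new_tokens being appended.
    trim_size = max(idx - max_size, 0)
    return (idx - trim_size) + new_tokens

# Example: a cache with max_size=8 holding 10 entries receives 3 new tokens;
# 2 old entries are trimmed, leaving 8 + 3 = 11 = max_size + S entries.
assert trimmed_length(10, 8, 3) == 11
# With the old "+ 1" in trim_size, only 7 old entries would remain and the
# first new token would see less than max_size context.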