fix rotating kv cache for chat use case

This commit is contained in:
Awni Hannun
2024-10-04 07:43:13 -07:00
parent 9bc53fc210
commit ed060a7c5c
3 changed files with 71 additions and 23 deletions

View File

@@ -79,28 +79,29 @@ class RotatingKVCache:
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)
def update_and_fetch(self, keys, values):
prev = self.offset
B, _, S = keys.shape[:3]
def _update_concat(self, keys, values):
if self.keys is None:
self.keys = keys
self.values = values
else:
if self._idx < self.keys.shape[2]:
self.keys = self.keys[..., : self._idx, :]
self.values = self.values[..., : self._idx, :]
# Prefill mode
if S > 1:
if self.keys is None:
self.keys = keys
self.values = values
else:
# The largest size is self.max_size + S - 1 to ensure
# every token gets at least self.max_size context
trim_size = self.keys.shape[2] - self.max_size + 1
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += S
self._idx = self.keys.shape[2]
return self.keys, self.values
# The largest size is self.max_size + S - 1 to ensure
# every token gets at least self.max_size context
trim_size = self._idx - self.max_size + 1
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += keys.shape[2]
self._idx = self.keys.shape[2]
return self.keys, self.values
# Generation mode
def _update_in_place(self, keys, values):
# May not have hit the max size yet, so potentially
# keep growing the cache
B, _, S = keys.shape[:3]
prev = self.offset
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
):
@@ -128,16 +129,23 @@ class RotatingKVCache:
self._idx = self.keep
# Assign
self.keys[..., self._idx : self._idx + 1, :] = keys
self.values[..., self._idx : self._idx + 1, :] = values
self.offset += 1
self._idx += 1
self.keys[..., self._idx : self._idx + S, :] = keys
self.values[..., self._idx : self._idx + S, :] = values
self.offset += S
self._idx += S
# If the buffer is not full, slice off the end
if self.offset < self.max_size:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
def update_and_fetch(self, keys, values):
S = keys.shape[2]
if S == 1 or (self.keys is not None and S < (self.keys.shape[2] - self._idx)):
return self._update_in_place(keys, values)
return self._update_concat(keys, values)
@property
def state(self):
return self.keys, self.values