Support non incremental kv cache growth (#766)

This commit is contained in:
Awni Hannun
2024-05-15 12:56:24 -07:00
committed by GitHub
parent 1a86d985d9
commit 69181e0058
2 changed files with 24 additions and 1 deletions

View File

@@ -16,12 +16,15 @@ class KVCache:
def update_and_fetch(self, keys, values):
prev = self.offset
if prev % self.step == 0:
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
n_steps = (self.step + keys.shape[2] - 1) // self.step
shape = (1, self.n_kv_heads, n_steps * self.step, self.head_dim)
new_k = mx.zeros(shape, keys.dtype)
new_v = mx.zeros(shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else: