mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Support non-incremental KV cache growth (#766)
This commit is contained in:
@@ -16,12 +16,15 @@ class KVCache:
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
prev = self.offset
|
||||
if prev % self.step == 0:
|
||||
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
|
||||
n_steps = (self.step + keys.shape[2] - 1) // self.step
|
||||
shape = (1, self.n_kv_heads, n_steps * self.step, self.head_dim)
|
||||
new_k = mx.zeros(shape, keys.dtype)
|
||||
new_v = mx.zeros(shape, values.dtype)
|
||||
if self.keys is not None:
|
||||
if prev % self.step != 0:
|
||||
self.keys = self.keys[..., :prev, :]
|
||||
self.values = self.values[..., :prev, :]
|
||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user