Support non incremental kv cache growth (#766)

This commit is contained in:
Awni Hannun 2024-05-15 12:56:24 -07:00 committed by GitHub
parent 1a86d985d9
commit 69181e0058
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 1 deletions

View File

@ -16,12 +16,15 @@ class KVCache:
def update_and_fetch(self, keys, values):
prev = self.offset
if prev % self.step == 0:
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
n_steps = (self.step + keys.shape[2] - 1) // self.step
shape = (1, self.n_kv_heads, n_steps * self.step, self.head_dim)
new_k = mx.zeros(shape, keys.dtype)
new_v = mx.zeros(shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:

View File

@ -9,6 +9,26 @@ from mlx_lm.models.base import KVCache
class TestModels(unittest.TestCase):
def test_kv_cache(self):
cache = KVCache(32, 4)
k = mx.ones((1, 4, 1, 32), mx.float16)
v = mx.ones((1, 4, 1, 32), mx.float16)
k_up, v_up = cache.update_and_fetch(k, v)
self.assertTrue(mx.array_equal(k_up, k))
self.assertTrue(mx.array_equal(v_up, v))
self.assertEqual(cache.offset, 1)
k = mx.ones((1, 4, cache.step, 32), mx.float16)
v = mx.ones((1, 4, cache.step, 32), mx.float16)
k_up, v_up = cache.update_and_fetch(k, v)
expected = mx.ones((1, 4, cache.step + 1, 32), mx.float16)
self.assertTrue(mx.array_equal(k_up, expected))
self.assertTrue(mx.array_equal(v_up, expected))
self.assertEqual(cache.offset, cache.step + 1)
def model_test_runner(self, model, model_type, vocab_size, num_layers):
self.assertEqual(len(model.layers), num_layers)