diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 15002f8f..1e184294 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -16,12 +16,15 @@ class KVCache: def update_and_fetch(self, keys, values): prev = self.offset - if prev % self.step == 0: + if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: n_steps = (self.step + keys.shape[2] - 1) // self.step shape = (1, self.n_kv_heads, n_steps * self.step, self.head_dim) new_k = mx.zeros(shape, keys.dtype) new_v = mx.zeros(shape, values.dtype) if self.keys is not None: + if prev % self.step != 0: + self.keys = self.keys[..., :prev, :] + self.values = self.values[..., :prev, :] self.keys = mx.concatenate([self.keys, new_k], axis=2) self.values = mx.concatenate([self.values, new_v], axis=2) else: diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index eb5a0625..af71b9d0 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -9,6 +9,26 @@ from mlx_lm.models.base import KVCache class TestModels(unittest.TestCase): + def test_kv_cache(self): + cache = KVCache(32, 4) + + k = mx.ones((1, 4, 1, 32), mx.float16) + v = mx.ones((1, 4, 1, 32), mx.float16) + + k_up, v_up = cache.update_and_fetch(k, v) + self.assertTrue(mx.array_equal(k_up, k)) + self.assertTrue(mx.array_equal(v_up, v)) + self.assertEqual(cache.offset, 1) + + k = mx.ones((1, 4, cache.step, 32), mx.float16) + v = mx.ones((1, 4, cache.step, 32), mx.float16) + k_up, v_up = cache.update_and_fetch(k, v) + + expected = mx.ones((1, 4, cache.step + 1, 32), mx.float16) + self.assertTrue(mx.array_equal(k_up, expected)) + self.assertTrue(mx.array_equal(v_up, expected)) + self.assertEqual(cache.offset, cache.step + 1) + def model_test_runner(self, model, model_type, vocab_size, num_layers): self.assertEqual(len(model.layers), num_layers)