From ed060a7c5c7132eb9c58df1f829b6f6a7f698ffa Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 4 Oct 2024 07:43:13 -0700 Subject: [PATCH] fix rotating kv cache for chat use case --- llms/mlx_lm/_version.py | 2 +- llms/mlx_lm/models/base.py | 52 ++++++++++++++++++++++---------------- llms/tests/test_models.py | 40 +++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 23 deletions(-) diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 8110c823..70239db6 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.18.2" +__version__ = "0.19.1" diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index dc19dd05..75f19642 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -79,28 +79,29 @@ class RotatingKVCache: to_cat.append(append) return mx.concatenate(to_cat, axis=2) - def update_and_fetch(self, keys, values): - prev = self.offset - B, _, S = keys.shape[:3] + def _update_concat(self, keys, values): + if self.keys is None: + self.keys = keys + self.values = values + else: + if self._idx < self.keys.shape[2]: + self.keys = self.keys[..., : self._idx, :] + self.values = self.values[..., : self._idx, :] - # Prefill mode - if S > 1: - if self.keys is None: - self.keys = keys - self.values = values - else: - # The largest size is self.max_size + S - 1 to ensure - # every token gets at least self.max_size context - trim_size = self.keys.shape[2] - self.max_size + 1 - self.keys = self._trim(trim_size, self.keys, keys) - self.values = self._trim(trim_size, self.values, values) - self.offset += S - self._idx = self.keys.shape[2] - return self.keys, self.values + # The largest size is self.max_size + S - 1 to ensure + # every token gets at least self.max_size context + trim_size = self._idx - self.max_size + 1 + self.keys = self._trim(trim_size, self.keys, keys) + self.values = self._trim(trim_size, self.values, values) + self.offset += keys.shape[2] + self._idx = self.keys.shape[2] + return self.keys, self.values - # Generation mode + def _update_in_place(self, keys, values): # May not have hit the max size yet, so potentially # keep growing the cache + B, _, S = keys.shape[:3] + prev = self.offset if self.keys is None or ( prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size ): @@ -128,16 +129,23 @@ class RotatingKVCache: self._idx = self.keep # Assign - self.keys[..., self._idx : self._idx + 1, :] = keys - self.values[..., self._idx : self._idx + 1, :] = values - self.offset += 1 - self._idx += 1 + self.keys[..., self._idx : self._idx + S, :] = keys + self.values[..., self._idx : self._idx + S, :] = values + self.offset += S + self._idx += S # If the buffer is not full, slice off the end if self.offset < self.max_size: return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] return self.keys, self.values + def update_and_fetch(self, keys, values): + S = keys.shape[2] + if S == 1 or (self.keys is not None and S < (self.keys.shape[2] - self._idx)): + return self._update_in_place(keys, values) + + return self._update_concat(keys, values) + @property def state(self): return self.keys, self.values diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index cd7e7fd0..cb676a47 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -88,6 +88,46 @@ class TestModels(unittest.TestCase): if idx >= 8: idx = 2 + def test_rotating_kv_cache_chat_mode(self): + # Test that the rotating kv cache can handle + # alternating prompt/prefill with generation + d = 4 + h = 2 + cache = RotatingKVCache(d, h, max_size=18, step=4) + + x = mx.random.uniform(shape=(1, h, 8, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 8) + self.assertEqual(cache.offset, 8) + + x = mx.random.uniform(shape=(1, h, 1, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 9) + self.assertEqual(cache.offset, 9) + self.assertTrue(mx.allclose(x, k[..., 8:9, :])) + + x = mx.random.uniform(shape=(1, h, 2, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 11) + self.assertEqual(cache.offset, 11) + self.assertTrue(mx.allclose(x, k[..., 9:11, :])) + + x = mx.random.uniform(shape=(1, h, 3, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 14) + self.assertEqual(cache.offset, 14) + self.assertTrue(mx.allclose(x, k[..., 11:14, :])) + + x = mx.random.uniform(shape=(1, h, 6, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(cache.offset, 20) + self.assertTrue(mx.allclose(x, k[..., -6:, :])) + + x = mx.random.uniform(shape=(1, h, 2, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(cache.offset, 22) + self.assertTrue(mx.allclose(x, k[..., -2:, :])) + def model_test_runner(self, model, model_type, vocab_size, num_layers): self.assertEqual(len(model.layers), num_layers)