fix rotating kv cache for chat use case

This commit is contained in:
Awni Hannun 2024-10-04 07:43:13 -07:00
parent 9bc53fc210
commit ed060a7c5c
3 changed files with 71 additions and 23 deletions

View File

@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.18.2"
__version__ = "0.19.1"

View File

@ -79,28 +79,29 @@ class RotatingKVCache:
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)
def update_and_fetch(self, keys, values):
prev = self.offset
B, _, S = keys.shape[:3]
def _update_concat(self, keys, values):
if self.keys is None:
self.keys = keys
self.values = values
else:
if self._idx < self.keys.shape[2]:
self.keys = self.keys[..., : self._idx, :]
self.values = self.values[..., : self._idx, :]
# Prefill mode
if S > 1:
if self.keys is None:
self.keys = keys
self.values = values
else:
# The largest size is self.max_size + S - 1 to ensure
# every token gets at least self.max_size context
trim_size = self.keys.shape[2] - self.max_size + 1
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += S
self._idx = self.keys.shape[2]
return self.keys, self.values
# The largest size is self.max_size + S - 1 to ensure
# every token gets at least self.max_size context
trim_size = self._idx - self.max_size + 1
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += keys.shape[2]
self._idx = self.keys.shape[2]
return self.keys, self.values
# Generation mode
def _update_in_place(self, keys, values):
# May not have hit the max size yet, so potentially
# keep growing the cache
B, _, S = keys.shape[:3]
prev = self.offset
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
):
@ -128,16 +129,23 @@ class RotatingKVCache:
self._idx = self.keep
# Assign
self.keys[..., self._idx : self._idx + 1, :] = keys
self.values[..., self._idx : self._idx + 1, :] = values
self.offset += 1
self._idx += 1
self.keys[..., self._idx : self._idx + S, :] = keys
self.values[..., self._idx : self._idx + S, :] = values
self.offset += S
self._idx += S
# If the buffer is not full, slice off the end
if self.offset < self.max_size:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
def update_and_fetch(self, keys, values):
S = keys.shape[2]
if S == 1 or (self.keys is not None and S < (self.keys.shape[2] - self._idx)):
return self._update_in_place(keys, values)
return self._update_concat(keys, values)
@property
def state(self):
return self.keys, self.values

View File

@ -88,6 +88,46 @@ class TestModels(unittest.TestCase):
if idx >= 8:
idx = 2
def test_rotating_kv_cache_chat_mode(self):
# Test that the rotating kv cache can handle
# alternating prompt/prefill with generation
d = 4
h = 2
cache = RotatingKVCache(d, h, max_size=18, step=4)
x = mx.random.uniform(shape=(1, h, 8, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 8)
self.assertEqual(cache.offset, 8)
x = mx.random.uniform(shape=(1, h, 1, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 9)
self.assertEqual(cache.offset, 9)
self.assertTrue(mx.allclose(x, k[..., 8:9, :]))
x = mx.random.uniform(shape=(1, h, 2, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 11)
self.assertEqual(cache.offset, 11)
self.assertTrue(mx.allclose(x, k[..., 9:11, :]))
x = mx.random.uniform(shape=(1, h, 3, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 14)
self.assertEqual(cache.offset, 14)
self.assertTrue(mx.allclose(x, k[..., 11:14, :]))
x = mx.random.uniform(shape=(1, h, 6, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(cache.offset, 20)
self.assertTrue(mx.allclose(x, k[..., -6:, :]))
x = mx.random.uniform(shape=(1, h, 2, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(cache.offset, 22)
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
def model_test_runner(self, model, model_type, vocab_size, num_layers):
self.assertEqual(len(model.layers), num_layers)