mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
fix rotating kv cache for chat use case
This commit is contained in:
@@ -1,3 +1,3 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
__version__ = "0.18.2"
|
__version__ = "0.19.1"
|
||||||
|
@@ -79,28 +79,29 @@ class RotatingKVCache:
|
|||||||
to_cat.append(append)
|
to_cat.append(append)
|
||||||
return mx.concatenate(to_cat, axis=2)
|
return mx.concatenate(to_cat, axis=2)
|
||||||
|
|
||||||
def update_and_fetch(self, keys, values):
|
def _update_concat(self, keys, values):
|
||||||
prev = self.offset
|
if self.keys is None:
|
||||||
B, _, S = keys.shape[:3]
|
self.keys = keys
|
||||||
|
self.values = values
|
||||||
|
else:
|
||||||
|
if self._idx < self.keys.shape[2]:
|
||||||
|
self.keys = self.keys[..., : self._idx, :]
|
||||||
|
self.values = self.values[..., : self._idx, :]
|
||||||
|
|
||||||
# Prefill mode
|
# The largest size is self.max_size + S - 1 to ensure
|
||||||
if S > 1:
|
# every token gets at least self.max_size context
|
||||||
if self.keys is None:
|
trim_size = self._idx - self.max_size + 1
|
||||||
self.keys = keys
|
self.keys = self._trim(trim_size, self.keys, keys)
|
||||||
self.values = values
|
self.values = self._trim(trim_size, self.values, values)
|
||||||
else:
|
self.offset += keys.shape[2]
|
||||||
# The largest size is self.max_size + S - 1 to ensure
|
self._idx = self.keys.shape[2]
|
||||||
# every token gets at least self.max_size context
|
return self.keys, self.values
|
||||||
trim_size = self.keys.shape[2] - self.max_size + 1
|
|
||||||
self.keys = self._trim(trim_size, self.keys, keys)
|
|
||||||
self.values = self._trim(trim_size, self.values, values)
|
|
||||||
self.offset += S
|
|
||||||
self._idx = self.keys.shape[2]
|
|
||||||
return self.keys, self.values
|
|
||||||
|
|
||||||
# Generation mode
|
def _update_in_place(self, keys, values):
|
||||||
# May not have hit the max size yet, so potentially
|
# May not have hit the max size yet, so potentially
|
||||||
# keep growing the cache
|
# keep growing the cache
|
||||||
|
B, _, S = keys.shape[:3]
|
||||||
|
prev = self.offset
|
||||||
if self.keys is None or (
|
if self.keys is None or (
|
||||||
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
|
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
|
||||||
):
|
):
|
||||||
@@ -128,16 +129,23 @@ class RotatingKVCache:
|
|||||||
self._idx = self.keep
|
self._idx = self.keep
|
||||||
|
|
||||||
# Assign
|
# Assign
|
||||||
self.keys[..., self._idx : self._idx + 1, :] = keys
|
self.keys[..., self._idx : self._idx + S, :] = keys
|
||||||
self.values[..., self._idx : self._idx + 1, :] = values
|
self.values[..., self._idx : self._idx + S, :] = values
|
||||||
self.offset += 1
|
self.offset += S
|
||||||
self._idx += 1
|
self._idx += S
|
||||||
|
|
||||||
# If the buffer is not full, slice off the end
|
# If the buffer is not full, slice off the end
|
||||||
if self.offset < self.max_size:
|
if self.offset < self.max_size:
|
||||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||||
return self.keys, self.values
|
return self.keys, self.values
|
||||||
|
|
||||||
|
def update_and_fetch(self, keys, values):
|
||||||
|
S = keys.shape[2]
|
||||||
|
if S == 1 or (self.keys is not None and S < (self.keys.shape[2] - self._idx)):
|
||||||
|
return self._update_in_place(keys, values)
|
||||||
|
|
||||||
|
return self._update_concat(keys, values)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self):
|
def state(self):
|
||||||
return self.keys, self.values
|
return self.keys, self.values
|
||||||
|
@@ -88,6 +88,46 @@ class TestModels(unittest.TestCase):
|
|||||||
if idx >= 8:
|
if idx >= 8:
|
||||||
idx = 2
|
idx = 2
|
||||||
|
|
||||||
|
def test_rotating_kv_cache_chat_mode(self):
|
||||||
|
# Test that the rotating kv cache can handle
|
||||||
|
# alternating prompt/prefill with generation
|
||||||
|
d = 4
|
||||||
|
h = 2
|
||||||
|
cache = RotatingKVCache(d, h, max_size=18, step=4)
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(1, h, 8, d))
|
||||||
|
k, v = cache.update_and_fetch(x, x)
|
||||||
|
self.assertEqual(k.shape[2], 8)
|
||||||
|
self.assertEqual(cache.offset, 8)
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(1, h, 1, d))
|
||||||
|
k, v = cache.update_and_fetch(x, x)
|
||||||
|
self.assertEqual(k.shape[2], 9)
|
||||||
|
self.assertEqual(cache.offset, 9)
|
||||||
|
self.assertTrue(mx.allclose(x, k[..., 8:9, :]))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(1, h, 2, d))
|
||||||
|
k, v = cache.update_and_fetch(x, x)
|
||||||
|
self.assertEqual(k.shape[2], 11)
|
||||||
|
self.assertEqual(cache.offset, 11)
|
||||||
|
self.assertTrue(mx.allclose(x, k[..., 9:11, :]))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(1, h, 3, d))
|
||||||
|
k, v = cache.update_and_fetch(x, x)
|
||||||
|
self.assertEqual(k.shape[2], 14)
|
||||||
|
self.assertEqual(cache.offset, 14)
|
||||||
|
self.assertTrue(mx.allclose(x, k[..., 11:14, :]))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(1, h, 6, d))
|
||||||
|
k, v = cache.update_and_fetch(x, x)
|
||||||
|
self.assertEqual(cache.offset, 20)
|
||||||
|
self.assertTrue(mx.allclose(x, k[..., -6:, :]))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(1, h, 2, d))
|
||||||
|
k, v = cache.update_and_fetch(x, x)
|
||||||
|
self.assertEqual(cache.offset, 22)
|
||||||
|
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
|
||||||
|
|
||||||
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
||||||
|
|
||||||
self.assertEqual(len(model.layers), num_layers)
|
self.assertEqual(len(model.layers), num_layers)
|
||||||
|
Reference in New Issue
Block a user