mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Support non incremental kv cache growth (#766)
This commit is contained in:
parent
1a86d985d9
commit
69181e0058
@ -16,12 +16,15 @@ class KVCache:
|
|||||||
|
|
||||||
def update_and_fetch(self, keys, values):
|
def update_and_fetch(self, keys, values):
|
||||||
prev = self.offset
|
prev = self.offset
|
||||||
if prev % self.step == 0:
|
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
|
||||||
n_steps = (self.step + keys.shape[2] - 1) // self.step
|
n_steps = (self.step + keys.shape[2] - 1) // self.step
|
||||||
shape = (1, self.n_kv_heads, n_steps * self.step, self.head_dim)
|
shape = (1, self.n_kv_heads, n_steps * self.step, self.head_dim)
|
||||||
new_k = mx.zeros(shape, keys.dtype)
|
new_k = mx.zeros(shape, keys.dtype)
|
||||||
new_v = mx.zeros(shape, values.dtype)
|
new_v = mx.zeros(shape, values.dtype)
|
||||||
if self.keys is not None:
|
if self.keys is not None:
|
||||||
|
if prev % self.step != 0:
|
||||||
|
self.keys = self.keys[..., :prev, :]
|
||||||
|
self.values = self.values[..., :prev, :]
|
||||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||||
else:
|
else:
|
||||||
|
@ -9,6 +9,26 @@ from mlx_lm.models.base import KVCache
|
|||||||
|
|
||||||
class TestModels(unittest.TestCase):
|
class TestModels(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_kv_cache(self):
|
||||||
|
cache = KVCache(32, 4)
|
||||||
|
|
||||||
|
k = mx.ones((1, 4, 1, 32), mx.float16)
|
||||||
|
v = mx.ones((1, 4, 1, 32), mx.float16)
|
||||||
|
|
||||||
|
k_up, v_up = cache.update_and_fetch(k, v)
|
||||||
|
self.assertTrue(mx.array_equal(k_up, k))
|
||||||
|
self.assertTrue(mx.array_equal(v_up, v))
|
||||||
|
self.assertEqual(cache.offset, 1)
|
||||||
|
|
||||||
|
k = mx.ones((1, 4, cache.step, 32), mx.float16)
|
||||||
|
v = mx.ones((1, 4, cache.step, 32), mx.float16)
|
||||||
|
k_up, v_up = cache.update_and_fetch(k, v)
|
||||||
|
|
||||||
|
expected = mx.ones((1, 4, cache.step + 1, 32), mx.float16)
|
||||||
|
self.assertTrue(mx.array_equal(k_up, expected))
|
||||||
|
self.assertTrue(mx.array_equal(v_up, expected))
|
||||||
|
self.assertEqual(cache.offset, cache.step + 1)
|
||||||
|
|
||||||
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
||||||
|
|
||||||
self.assertEqual(len(model.layers), num_layers)
|
self.assertEqual(len(model.layers), num_layers)
|
||||||
|
Loading…
Reference in New Issue
Block a user