mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Handle longer prompt/generation (#931)
* rebase * nits * nit * fix rotating cache with step prefill * update version
This commit is contained in:
@@ -4,7 +4,7 @@ import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
from mlx_lm.models.base import KVCache
|
||||
from mlx_lm.models.base import KVCache, RotatingKVCache
|
||||
|
||||
|
||||
class TestModels(unittest.TestCase):
|
||||
@@ -29,6 +29,64 @@ class TestModels(unittest.TestCase):
|
||||
self.assertTrue(mx.array_equal(v_up, expected))
|
||||
self.assertEqual(cache.offset, cache.step + 1)
|
||||
|
||||
def test_rotating_kv_cache(self):
    """Check what ``RotatingKVCache.update_and_fetch`` returns over many updates.

    Two caches are exercised, both built with ``max_size=8`` and ``step=4``:

    * one with the default ``keep``: updates of 2, 5 and 4 tokens, then ten
      single-token updates whose expected position in the returned tensors
      advances from 0 and wraps modulo 8;
    * one with ``keep=2``: a single 20-token update, then ten single-token
      updates whose expected position cycles through 2..7 while
      ``cache.offset`` is expected to count up from 21.
    """
    batch, n_heads, head_dim = 1, 2, 32
    window = 8
    kept = 2

    def random_kv(num_tokens):
        # Keys are sampled before values, one uniform draw per tensor.
        dims = (batch, n_heads, num_tokens, head_dim)
        keys = mx.random.uniform(shape=dims)
        values = mx.random.uniform(shape=dims)
        return keys, values

    def push_one_and_check(kv_cache, slot):
        # Feed one token and require it to appear at ``slot`` along the
        # second-to-last axis of both returned tensors.
        keys, values = random_kv(1)
        got_k, got_v = kv_cache.update_and_fetch(keys, values)
        self.assertTrue(mx.array_equal(got_k[..., slot : slot + 1, :], keys))
        self.assertTrue(mx.array_equal(got_v[..., slot : slot + 1, :], values))

    rotating = RotatingKVCache(head_dim, n_heads, max_size=window, step=4)

    # First update: the fetch is expected to return exactly what was stored.
    keys, values = random_kv(2)
    got_k, got_v = rotating.update_and_fetch(keys, values)
    self.assertTrue(mx.array_equal(got_k, keys))
    self.assertTrue(mx.array_equal(got_v, values))
    self.assertEqual(rotating.offset, 2)

    # Second update: the 5 new tokens are expected from position 2 onward.
    keys, values = random_kv(5)
    got_k, got_v = rotating.update_and_fetch(keys, values)
    self.assertTrue(mx.array_equal(got_k[..., 2:, :], keys))
    self.assertTrue(mx.array_equal(got_v[..., 2:, :], values))

    # Third update: the 4 new tokens are expected as the last 4 positions.
    keys, values = random_kv(4)
    got_k, got_v = rotating.update_and_fetch(keys, values)
    self.assertTrue(mx.array_equal(got_k[..., -4:, :], keys))
    self.assertTrue(mx.array_equal(got_v[..., -4:, :], values))

    # Single-token updates: expected position starts at 0 and wraps at 8.
    for step_no in range(10):
        push_one_and_check(rotating, step_no % window)

    # Try with nonzero keep
    pinned = RotatingKVCache(
        head_dim, n_heads, max_size=window, step=4, keep=kept
    )

    # Check a large update
    keys, values = random_kv(20)
    got_k, got_v = pinned.update_and_fetch(keys, values)
    self.assertTrue(mx.array_equal(got_k, keys))
    self.assertTrue(mx.array_equal(got_v, values))

    # A bunch of small updates
    self.assertEqual(pinned.offset, 20)
    slot = kept
    for expected_offset in range(21, 31):
        push_one_and_check(pinned, slot)
        self.assertEqual(pinned.offset, expected_offset)
        # Expected position restarts at 2 (not 0) once it reaches 8.
        slot = slot + 1 if slot + 1 < window else kept
|
||||
|
||||
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
||||
|
||||
self.assertEqual(len(model.layers), num_layers)
|
||||
|
Reference in New Issue
Block a user