Support non incremental kv cache growth (#766)

This commit is contained in:
Awni Hannun
2024-05-15 12:56:24 -07:00
committed by GitHub
parent 1a86d985d9
commit 69181e0058
2 changed files with 24 additions and 1 deletions

View File

@@ -9,6 +9,26 @@ from mlx_lm.models.base import KVCache
class TestModels(unittest.TestCase):
def test_kv_cache(self):
    """Exercise ``KVCache.update_and_fetch`` with a small and a large update.

    The first update adds arrays whose third axis has length 1; the second
    adds ``cache.step`` entries in a single call. After each call the test
    checks the fetched keys/values and the cache's ``offset``.
    """
    kv_cache = KVCache(32, 4)

    # --- First update: arrays of shape (1, 4, 1, 32). ---
    single_shape = (1, 4, 1, 32)
    keys = mx.ones(single_shape, mx.float16)
    values = mx.ones(single_shape, mx.float16)
    fetched_keys, fetched_values = kv_cache.update_and_fetch(keys, values)

    # With nothing cached beforehand, the fetched arrays equal the inputs.
    self.assertTrue(mx.array_equal(fetched_keys, keys))
    self.assertTrue(mx.array_equal(fetched_values, values))
    self.assertEqual(kv_cache.offset, 1)

    # --- Second update: `step` entries along the third axis in one call. ---
    # NOTE(review): presumably this pushes the cache past its currently
    # allocated size (the "non incremental growth" case) -- confirm
    # against the KVCache implementation.
    keys = mx.ones((1, 4, kv_cache.step, 32), mx.float16)
    values = mx.ones((1, 4, kv_cache.step, 32), mx.float16)
    fetched_keys, fetched_values = kv_cache.update_and_fetch(keys, values)

    # The fetched arrays now hold the first entry plus the `step` new ones.
    all_ones = mx.ones((1, 4, kv_cache.step + 1, 32), mx.float16)
    self.assertTrue(mx.array_equal(fetched_keys, all_ones))
    self.assertTrue(mx.array_equal(fetched_values, all_ones))
    self.assertEqual(kv_cache.offset, kv_cache.step + 1)
def model_test_runner(self, model, model_type, vocab_size, num_layers):
self.assertEqual(len(model.layers), num_layers)