From 29f21e7fe4dacc603e51e899426d38f33147c7b1 Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Mon, 28 Oct 2024 22:14:52 -0700
Subject: [PATCH] add tests

---
 llms/mlx_lm/models/cache.py     | 22 +++++++++---
 llms/tests/test_prompt_cache.py | 63 +++++++++++++++++++++++++++++++++
 2 files changed, 81 insertions(+), 4 deletions(-)

diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index 0883e573..f4efe41e 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -238,6 +238,14 @@ class QuantizedKVCache(_BaseCache):
     def meta_state(self, v):
         self.step, self.offset, self.group_size, self.bits = map(int, v)
 
+    def is_trimmable(self):
+        return True
+
+    def trim(self, n):
+        n = min(self.offset, n)
+        self.offset -= n
+        return n
+
 
 class KVCache(_BaseCache):
     def __init__(self):
@@ -296,8 +304,11 @@ class KVCache(_BaseCache):
     def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
         quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
         quant_cache.offset = self.offset
-        quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
-        quant_cache.values = mx.quantize(self.values, group_size=group_size, bits=bits)
+        if self.keys is not None:
+            quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
+            quant_cache.values = mx.quantize(
+                self.values, group_size=group_size, bits=bits
+            )
         return quant_cache
 
 
@@ -443,8 +454,11 @@ class RotatingKVCache(_BaseCache):
     def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
         quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
         quant_cache.offset = self.offset
-        quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
-        quant_cache.values = mx.quantize(self.values, group_size=group_size, bits=bits)
+        if self.keys is not None:
+            quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
+            quant_cache.values = mx.quantize(
+                self.values, group_size=group_size, bits=bits
+            )
         return quant_cache
 
 
diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py
index 64cd9486..1e57bd86 100644
--- a/llms/tests/test_prompt_cache.py
+++ b/llms/tests/test_prompt_cache.py
@@ -9,6 +9,7 @@ import mlx.core as mx
 from mlx_lm.models.cache import (
     KVCache,
     MambaCache,
+    QuantizedKVCache,
     RotatingKVCache,
     load_prompt_cache,
     make_prompt_cache,
@@ -186,6 +187,18 @@ class TestPromptCache(unittest.TestCase):
         num_trimmed = trim_prompt_cache(cache, 4)
         self.assertEqual(num_trimmed, 0)
 
+        cache = [QuantizedKVCache() for _ in range(2)]
+        for c in cache:
+            x = mx.random.uniform(shape=(1, 8, 10, 64))
+            c.update_and_fetch(x, x)
+
+        num_trimmed = trim_prompt_cache(cache, 7)
+        self.assertEqual(num_trimmed, 7)
+
+        # Trim more tokens than remain
+        num_trimmed = trim_prompt_cache(cache, 4)
+        self.assertEqual(num_trimmed, 3)
+
     def test_trim_cache_with_generate(self):
         model, tokenizer = load(HF_MODEL_PATH)
         prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
@@ -238,6 +251,56 @@ class TestPromptCache(unittest.TestCase):
         self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
         self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
 
+    def test_save_load_quantized_cache(self):
+        cache = [QuantizedKVCache(bits=4, group_size=32) for _ in range(4)]
+        for c in cache:
+            x = mx.random.uniform(shape=(1, 8, 10, 32))
+            c.update_and_fetch(x, x)
+        cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
+        save_prompt_cache(cache_file, cache)
+        loaded_cache = load_prompt_cache(cache_file)
+
+        self.assertTrue(loaded_cache[0].bits == cache[0].bits)
+        self.assertTrue(loaded_cache[0].group_size == cache[0].group_size)
+        self.assertEqual(len(cache), len(loaded_cache))
+        for c, lc in zip(cache, loaded_cache):
+            self.assertEqual(c.offset, lc.offset)
+            # Loop over quantized tuple
+            for i in range(3):
+                self.assertTrue(mx.array_equal(c.state[0][i], lc.state[0][i]))
+                self.assertTrue(mx.array_equal(c.state[1][i], lc.state[1][i]))
+
+        # Test with metadata
+        cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
+        metadata = {"a": "b", "c": "d"}
+        save_prompt_cache(cache_file, cache, metadata)
+        _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True)
+        self.assertEqual(metadata, loaded_metadata)
+
+    def test_cache_to_quantized(self):
+        model, tokenizer = load(HF_MODEL_PATH)
+        prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
+        results = zip(range(4), generate_step(prompt, model))
+        toks, all_logits = zip(*(r[1] for r in results))
+
+        prompt_cache = make_prompt_cache(model)
+        i = 0
+        for _, (tok, logits) in zip(
+            range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
+        ):
+            self.assertEqual(tok, toks[i])
+            self.assertTrue(mx.allclose(logits, all_logits[i]))
+            i += 1
+
+        prompt_cache = [c.to_quantized(bits=8, group_size=32) for c in prompt_cache]
+
+        for _, (tok, logits) in zip(
+            range(1),
+            generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
+        ):
+            i += 1
+            self.assertEqual(tok, toks[i])
+            self.assertTrue(mx.allclose(logits, all_logits[i], rtol=1e-2))
+
 
 if __name__ == "__main__":
     unittest.main()
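
Note: below is a minimal sketch (not part of the patch) of the behavior the new tests exercise, assuming the public mlx_lm.models.cache API used by the test file (QuantizedKVCache, KVCache, trim_prompt_cache); the shapes and trim sizes are illustrative only.

# Sketch only: the new QuantizedKVCache.trim()/is_trimmable() and the
# None-guarded to_quantized() added by this patch.
import mlx.core as mx
from mlx_lm.models.cache import KVCache, QuantizedKVCache, trim_prompt_cache

# A quantized cache is now trimmable: trim() clamps to the current offset.
cache = [QuantizedKVCache() for _ in range(2)]
for c in cache:
    x = mx.random.uniform(shape=(1, 8, 10, 64))  # (batch, heads, tokens, dims)
    c.update_and_fetch(x, x)

print(trim_prompt_cache(cache, 7))  # 7  (10 tokens cached, 7 removed)
print(trim_prompt_cache(cache, 4))  # 3  (only 3 tokens were left)

# to_quantized() on an empty cache no longer calls mx.quantize on None;
# it simply returns an empty QuantizedKVCache with the same offset.
quantized = KVCache().to_quantized(group_size=32, bits=4)
print(quantized.offset)  # 0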