mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
add tests
This commit is contained in:
parent
37a3723823
commit
29f21e7fe4
@ -238,6 +238,14 @@ class QuantizedKVCache(_BaseCache):
|
|||||||
def meta_state(self, v):
|
def meta_state(self, v):
|
||||||
self.step, self.offset, self.group_size, self.bits = map(int, v)
|
self.step, self.offset, self.group_size, self.bits = map(int, v)
|
||||||
|
|
||||||
|
def is_trimmable(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def trim(self, n):
|
||||||
|
n = min(self.offset, n)
|
||||||
|
self.offset -= n
|
||||||
|
return n
|
||||||
|
|
||||||
|
|
||||||
class KVCache(_BaseCache):
|
class KVCache(_BaseCache):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -296,8 +304,11 @@ class KVCache(_BaseCache):
|
|||||||
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
||||||
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
|
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
|
||||||
quant_cache.offset = self.offset
|
quant_cache.offset = self.offset
|
||||||
quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
|
if self.keys is not None:
|
||||||
quant_cache.values = mx.quantize(self.values, group_size=group_size, bits=bits)
|
quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
|
||||||
|
quant_cache.values = mx.quantize(
|
||||||
|
self.values, group_size=group_size, bits=bits
|
||||||
|
)
|
||||||
return quant_cache
|
return quant_cache
|
||||||
|
|
||||||
|
|
||||||
@ -443,8 +454,11 @@ class RotatingKVCache(_BaseCache):
|
|||||||
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
||||||
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
|
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
|
||||||
quant_cache.offset = self.offset
|
quant_cache.offset = self.offset
|
||||||
quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
|
if self.keys is not None:
|
||||||
quant_cache.values = mx.quantize(self.values, group_size=group_size, bits=bits)
|
quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
|
||||||
|
quant_cache.values = mx.quantize(
|
||||||
|
self.values, group_size=group_size, bits=bits
|
||||||
|
)
|
||||||
return quant_cache
|
return quant_cache
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@ import mlx.core as mx
|
|||||||
from mlx_lm.models.cache import (
|
from mlx_lm.models.cache import (
|
||||||
KVCache,
|
KVCache,
|
||||||
MambaCache,
|
MambaCache,
|
||||||
|
QuantizedKVCache,
|
||||||
RotatingKVCache,
|
RotatingKVCache,
|
||||||
load_prompt_cache,
|
load_prompt_cache,
|
||||||
make_prompt_cache,
|
make_prompt_cache,
|
||||||
@ -186,6 +187,18 @@ class TestPromptCache(unittest.TestCase):
|
|||||||
num_trimmed = trim_prompt_cache(cache, 4)
|
num_trimmed = trim_prompt_cache(cache, 4)
|
||||||
self.assertEqual(num_trimmed, 0)
|
self.assertEqual(num_trimmed, 0)
|
||||||
|
|
||||||
|
cache = [QuantizedKVCache() for _ in range(2)]
|
||||||
|
for c in cache:
|
||||||
|
x = mx.random.uniform(shape=(1, 8, 10, 64))
|
||||||
|
c.update_and_fetch(x, x)
|
||||||
|
|
||||||
|
num_trimmed = trim_prompt_cache(cache, 7)
|
||||||
|
self.assertEqual(num_trimmed, 7)
|
||||||
|
|
||||||
|
# Trim more tokens than remain
|
||||||
|
num_trimmed = trim_prompt_cache(cache, 4)
|
||||||
|
self.assertEqual(num_trimmed, 3)
|
||||||
|
|
||||||
def test_trim_cache_with_generate(self):
|
def test_trim_cache_with_generate(self):
|
||||||
model, tokenizer = load(HF_MODEL_PATH)
|
model, tokenizer = load(HF_MODEL_PATH)
|
||||||
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
|
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
|
||||||
@ -238,6 +251,56 @@ class TestPromptCache(unittest.TestCase):
|
|||||||
self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
|
self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
|
||||||
self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
|
self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
|
||||||
|
|
||||||
|
def test_save_load_quantized_cache(self):
|
||||||
|
cache = [QuantizedKVCache(bits=4, group_size=32) for _ in range(4)]
|
||||||
|
for c in cache:
|
||||||
|
x = mx.random.uniform(shape=(1, 8, 10, 32))
|
||||||
|
c.update_and_fetch(x, x)
|
||||||
|
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
|
||||||
|
save_prompt_cache(cache_file, cache)
|
||||||
|
loaded_cache = load_prompt_cache(cache_file)
|
||||||
|
self.assertTrue(loaded_cache[0].bits == cache[0].bits)
|
||||||
|
self.assertTrue(loaded_cache[0].group_size == cache[0].group_size)
|
||||||
|
self.assertTrue(len(cache), len(loaded_cache))
|
||||||
|
for c, lc in zip(cache, loaded_cache):
|
||||||
|
self.assertEqual(c.offset, lc.offset)
|
||||||
|
# Loop over quantized tuple
|
||||||
|
for i in range(3):
|
||||||
|
self.assertTrue(mx.array_equal(c.state[0][i], lc.state[0][i]))
|
||||||
|
self.assertTrue(mx.array_equal(c.state[1][i], lc.state[1][i]))
|
||||||
|
|
||||||
|
# Test with metadata
|
||||||
|
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
|
||||||
|
metadata = {"a": "b", "c": "d"}
|
||||||
|
save_prompt_cache(cache_file, cache, metadata)
|
||||||
|
_, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True)
|
||||||
|
self.assertEqual(metadata, loaded_metadata)
|
||||||
|
|
||||||
|
def test_cache_to_quantized(self):
|
||||||
|
model, tokenizer = load(HF_MODEL_PATH)
|
||||||
|
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
|
||||||
|
results = zip(range(4), generate_step(prompt, model))
|
||||||
|
toks, all_logits = zip(*(r[1] for r in results))
|
||||||
|
|
||||||
|
prompt_cache = make_prompt_cache(model)
|
||||||
|
i = 0
|
||||||
|
for _, (tok, logits) in zip(
|
||||||
|
range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
|
||||||
|
):
|
||||||
|
self.assertEqual(tok, toks[i])
|
||||||
|
self.assertTrue(mx.allclose(logits, all_logits[i]))
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
prompt_cache = [c.to_quantized(bits=8, group_size=32) for c in prompt_cache]
|
||||||
|
|
||||||
|
for _, (tok, logits) in zip(
|
||||||
|
range(1),
|
||||||
|
generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
|
||||||
|
):
|
||||||
|
i += 1
|
||||||
|
self.assertEqual(tok, toks[i])
|
||||||
|
self.assertTrue(mx.allclose(logits, all_logits[i], rtol=1e-2))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user