fixes + trim_prompt_cache api

This commit is contained in:
Awni Hannun 2024-10-07 16:50:03 -07:00
parent 7a3d0dd459
commit fbff8e2fd5
2 changed files with 120 additions and 5 deletions

View File

@ -46,7 +46,7 @@ def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str]
cache_info = [c.meta_state for c in cache]
cache_data = dict(tree_flatten(cache_data))
cache_classes = [type(c).__name__ for c in cache]
cache_metadata = [cache_classes, cache_info, metadata]
cache_metadata = [cache_info, metadata, cache_classes]
cache_metadata = dict(tree_flatten(cache_metadata))
mx.save_safetensors(file_name, cache_data, cache_metadata)
@ -67,7 +67,7 @@ def load_prompt_cache(file_name, return_metadata=False):
arrays, cache_metadata = mx.load(file_name, return_metadata=True)
arrays = tree_unflatten(list(arrays.items()))
cache_metadata = tree_unflatten(list(cache_metadata.items()))
classes, info, metadata = cache_metadata
info, metadata, classes = cache_metadata
cache = [globals()[c]() for c in classes]
for c, state, meta_state in zip(cache, arrays, info):
c.state = state
@ -77,6 +77,25 @@ def load_prompt_cache(file_name, return_metadata=False):
return cache
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
"""
Trim the model's cache by the given number of tokens.
This function will trim the cache if possible (in-place) and return the
number of tokens that were trimmed.
Args:
cache (List[Any]): The model's cache.
num_tokens (int): The number of tokens to trim.
Returns:
(int): The number of tokens that were trimmed.
"""
if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
return 0
return [c.trim(num_tokens) for c in cache][0]
class _BaseCache:
@property
def state(self):
@ -89,16 +108,18 @@ class _BaseCache:
@property
def meta_state(self):
return {}
return ""
@state.setter
@meta_state.setter
def meta_state(self, v):
if v is not None and v:
raise ValueError("This cache has no meta_state but a meta_state was set.")
def is_trimmable(self):
return False
class KVCache(_BaseCache):
def __init__(self):
self.keys = None
self.values = None
@ -144,6 +165,14 @@ class KVCache(_BaseCache):
self.keys, self.values = v
self.offset = self.keys.shape[2]
def is_trimmable(self):
return True
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
return n
class RotatingKVCache(_BaseCache):
@ -275,6 +304,15 @@ class RotatingKVCache(_BaseCache):
v,
)
def is_trimmable(self):
return self.offset < self.max_size
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
self._idx -= n
return n
class MambaCache(_BaseCache):
def __init__(self):

View File

@ -12,6 +12,7 @@ from mlx_lm.models.cache import (
load_prompt_cache,
make_prompt_cache,
save_prompt_cache,
trim_prompt_cache,
)
from mlx_lm.utils import generate_step, load
@ -138,6 +139,82 @@ class TestPromptCache(unittest.TestCase):
self.assertEqual(tok, toks[i])
self.assertTrue(mx.allclose(logits, all_logits[i]))
def test_trim_cache(self):
cache = [KVCache() for _ in range(2)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 4))
c.update_and_fetch(x, x)
# Trim
num_trimmed = trim_prompt_cache(cache, 7)
self.assertEqual(num_trimmed, 7)
# Trim more tokens than remain
num_trimmed = trim_prompt_cache(cache, 4)
self.assertEqual(num_trimmed, 3)
# Can't trim mamba cache
cache = [MambaCache() for _ in range(2)]
for c in cache:
c.state = mx.zeros((5, 5))
num_trimmed = trim_prompt_cache(cache, 7)
self.assertEqual(num_trimmed, 0)
# All cache's have to be trimmable
cache = [MambaCache(), KVCache()]
cache[0].state = mx.zeros((5, 5))
x = mx.random.uniform(shape=(1, 8, 10, 4))
cache[1].update_and_fetch(x, x)
num_trimmed = trim_prompt_cache(cache, 1)
self.assertEqual(num_trimmed, 0)
cache = [RotatingKVCache(max_size=6) for _ in range(2)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 5, 4))
c.update_and_fetch(x, x)
num_trimmed = trim_prompt_cache(cache, 4)
self.assertEqual(num_trimmed, 4)
# Can't trim fixed-size KV cache after processing
# more than max_kv_size tokens
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 4))
c.update_and_fetch(x, x)
num_trimmed = trim_prompt_cache(cache, 4)
self.assertEqual(num_trimmed, 0)
def test_trim_cache_with_generate(self):
model, tokenizer = load(HF_MODEL_PATH)
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
prompt_cache = make_prompt_cache(model)
# Generate one token so we process the full prompt
last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache))
last_tok = mx.array([last_tok])
# Generate two more tokens
results = zip(
range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
)
toks, all_logits = zip(*(r[1] for r in results))
# To get back to the cache just after processing the prompt,
# trim by 3 tokens
trim_prompt_cache(prompt_cache, 3)
# Generate the same thing again
results = zip(
range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
)
second_toks, second_all_logits = zip(*(r[1] for r in results))
self.assertEqual(toks, second_toks)
self.assertTrue(
all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
)
if __name__ == "__main__":
unittest.main()