mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
fixes + trim_prompt_cache api
llms/mlx_lm/models/cache.py

@@ -46,7 +46,7 @@ def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str]
     cache_info = [c.meta_state for c in cache]
     cache_data = dict(tree_flatten(cache_data))
     cache_classes = [type(c).__name__ for c in cache]
-    cache_metadata = [cache_classes, cache_info, metadata]
+    cache_metadata = [cache_info, metadata, cache_classes]
     cache_metadata = dict(tree_flatten(cache_metadata))
     mx.save_safetensors(file_name, cache_data, cache_metadata)
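
The reordering above is the "fixes" half of the commit: save_prompt_cache now writes [cache_info, metadata, cache_classes], matching the unpacking order in load_prompt_cache below. A minimal round-trip sketch using only functions shown in this commit (the model path and file name are illustrative, not from the commit):

    import mlx.core as mx
    from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
    from mlx_lm.utils import generate_step, load

    model, tokenizer = load("some/model-path")  # hypothetical path
    prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]

    # Fill a per-layer cache by processing the prompt once.
    prompt_cache = make_prompt_cache(model)
    next(generate_step(prompt, model, prompt_cache=prompt_cache))

    # Persist and restore; classes, states, and meta states all round-trip.
    save_prompt_cache("prompt_cache.safetensors", prompt_cache)
    restored = load_prompt_cache("prompt_cache.safetensors")
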
@@ -67,7 +67,7 @@ def load_prompt_cache(file_name, return_metadata=False):
     arrays, cache_metadata = mx.load(file_name, return_metadata=True)
     arrays = tree_unflatten(list(arrays.items()))
     cache_metadata = tree_unflatten(list(cache_metadata.items()))
-    classes, info, metadata = cache_metadata
+    info, metadata, classes = cache_metadata
     cache = [globals()[c]() for c in classes]
     for c, state, meta_state in zip(cache, arrays, info):
         c.state = state
@@ -77,6 +77,25 @@ def load_prompt_cache(file_name, return_metadata=False):
     return cache


+def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
+    """
+    Trim the model's cache by the given number of tokens.
+
+    This function will trim the cache if possible (in-place) and return the
+    number of tokens that were trimmed.
+
+    Args:
+        cache (List[Any]): The model's cache.
+        num_tokens (int): The number of tokens to trim.
+
+    Returns:
+        (int): The number of tokens that were trimmed.
+    """
+    if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
+        return 0
+    return [c.trim(num_tokens) for c in cache][0]
+
+
 class _BaseCache:
     @property
     def state(self):
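
trim_prompt_cache is all-or-nothing: if any layer's cache is not trimmable it returns 0; otherwise every layer is trimmed in place and the per-layer count, clamped to the number of tokens actually cached, is returned. A small sketch of the semantics, mirroring the tests further down:

    import mlx.core as mx
    from mlx_lm.models.cache import KVCache, trim_prompt_cache

    cache = [KVCache() for _ in range(2)]       # one cache per model layer
    x = mx.random.uniform(shape=(1, 8, 10, 4))  # (batch, heads, tokens, dims)
    for c in cache:
        c.update_and_fetch(x, x)                # 10 tokens cached per layer

    trim_prompt_cache(cache, 2)                 # -> 2, offsets rewound to 8
    trim_prompt_cache(cache, 100)               # -> 8, clamped to what remains
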
@@ -89,16 +108,18 @@ class _BaseCache:

     @property
     def meta_state(self):
-        return {}
+        return ""

-    @state.setter
+    @meta_state.setter
     def meta_state(self, v):
         if v is not None and v:
             raise ValueError("This cache has no meta_state but a meta_state was set.")

+    def is_trimmable(self):
+        return False


 class KVCache(_BaseCache):

     def __init__(self):
         self.keys = None
         self.values = None
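
Switching the default meta_state from {} to "" matches how it is stored: save_prompt_cache puts meta_state into safetensors metadata, which per the signature above is a flat Dict[str, str], so meta state has to round-trip through strings. A hypothetical subclass (illustrative only, not part of this commit, written as it would appear inside cache.py) showing the expected shape:

    class WindowedCache(_BaseCache):
        """Illustrative: a cache that persists two ints via meta_state."""

        def __init__(self, window=256):
            self.window = window
            self.offset = 0

        @property
        def meta_state(self):
            # Everything is stringified for the Dict[str, str] metadata.
            return tuple(map(str, (self.window, self.offset)))

        @meta_state.setter
        def meta_state(self, v):
            self.window, self.offset = map(int, v)
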
@@ -144,6 +165,14 @@ class KVCache(_BaseCache):
         self.keys, self.values = v
         self.offset = self.keys.shape[2]

+    def is_trimmable(self):
+        return True
+
+    def trim(self, n):
+        n = min(self.offset, n)
+        self.offset -= n
+        return n
+

 class RotatingKVCache(_BaseCache):

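
KVCache.trim only rewinds offset; the key/value buffers are left alone, and the next update_and_fetch overwrites the released slots. A sketch of the observable behavior (shapes as in the tests):

    import mlx.core as mx
    from mlx_lm.models.cache import KVCache

    c = KVCache()
    x = mx.random.uniform(shape=(1, 8, 10, 4))
    c.update_and_fetch(x, x)
    assert c.offset == 10          # 10 tokens cached

    assert c.trim(3) == 3          # logical length rewound to 7
    assert c.offset == 7

    y = mx.random.uniform(shape=(1, 8, 3, 4))
    keys, _ = c.update_and_fetch(y, y)
    assert keys.shape[2] == 10     # the 3 stale slots were overwritten
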
@@ -275,6 +304,15 @@ class RotatingKVCache(_BaseCache):
             v,
         )

+    def is_trimmable(self):
+        return self.offset < self.max_size
+
+    def trim(self, n):
+        n = min(self.offset, n)
+        self.offset -= n
+        self._idx -= n
+        return n
+

 class MambaCache(_BaseCache):
     def __init__(self):
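
A RotatingKVCache is only trimmable while offset < max_size, i.e. before the ring buffer wraps; after that the oldest tokens have been overwritten and there is no coherent earlier state to rewind to. Its trim also rewinds _idx together with offset so the next write lands in the right slot. A sketch of the cutoff:

    import mlx.core as mx
    from mlx_lm.models.cache import RotatingKVCache

    c = RotatingKVCache(max_size=6)
    x = mx.random.uniform(shape=(1, 8, 5, 4))
    c.update_and_fetch(x, x)
    assert c.is_trimmable()        # 5 cached tokens < max_size

    c.update_and_fetch(x, x)       # 10 tokens have now passed through
    assert not c.is_trimmable()    # trim_prompt_cache would return 0
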
llms/tests/test_prompt_cache.py

@@ -12,6 +12,7 @@ from mlx_lm.models.cache import (
     load_prompt_cache,
     make_prompt_cache,
     save_prompt_cache,
+    trim_prompt_cache,
 )
 from mlx_lm.utils import generate_step, load

@@ -138,6 +139,82 @@ class TestPromptCache(unittest.TestCase):
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))

+    def test_trim_cache(self):
+        cache = [KVCache() for _ in range(2)]
+        for c in cache:
+            x = mx.random.uniform(shape=(1, 8, 10, 4))
+            c.update_and_fetch(x, x)
+
+        # Trim
+        num_trimmed = trim_prompt_cache(cache, 7)
+        self.assertEqual(num_trimmed, 7)
+
+        # Trim more tokens than remain
+        num_trimmed = trim_prompt_cache(cache, 4)
+        self.assertEqual(num_trimmed, 3)
+
+        # Can't trim mamba cache
+        cache = [MambaCache() for _ in range(2)]
+        for c in cache:
+            c.state = mx.zeros((5, 5))
+        num_trimmed = trim_prompt_cache(cache, 7)
+        self.assertEqual(num_trimmed, 0)
+
+        # All caches have to be trimmable
+        cache = [MambaCache(), KVCache()]
+        cache[0].state = mx.zeros((5, 5))
+        x = mx.random.uniform(shape=(1, 8, 10, 4))
+        cache[1].update_and_fetch(x, x)
+        num_trimmed = trim_prompt_cache(cache, 1)
+        self.assertEqual(num_trimmed, 0)
+
+        cache = [RotatingKVCache(max_size=6) for _ in range(2)]
+        for c in cache:
+            x = mx.random.uniform(shape=(1, 8, 5, 4))
+            c.update_and_fetch(x, x)
+
+        num_trimmed = trim_prompt_cache(cache, 4)
+        self.assertEqual(num_trimmed, 4)
+
+        # Can't trim fixed-size KV cache after processing
+        # more than max_kv_size tokens
+        for c in cache:
+            x = mx.random.uniform(shape=(1, 8, 10, 4))
+            c.update_and_fetch(x, x)
+
+        num_trimmed = trim_prompt_cache(cache, 4)
+        self.assertEqual(num_trimmed, 0)
+
+    def test_trim_cache_with_generate(self):
+        model, tokenizer = load(HF_MODEL_PATH)
+        prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
+
+        prompt_cache = make_prompt_cache(model)
+
+        # Generate one token so we process the full prompt
+        last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache))
+        last_tok = mx.array([last_tok])
+
+        # Generate two more tokens
+        results = zip(
+            range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
+        )
+        toks, all_logits = zip(*(r[1] for r in results))
+
+        # To get back to the cache just after processing the prompt,
+        # trim by 3 tokens
+        trim_prompt_cache(prompt_cache, 3)
+
+        # Generate the same thing again
+        results = zip(
+            range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
+        )
+        second_toks, second_all_logits = zip(*(r[1] for r in results))
+        self.assertEqual(toks, second_toks)
+        self.assertTrue(
+            all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
+        )
+

 if __name__ == "__main__":
     unittest.main()
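
The generate test above is the key guarantee: after trimming, decoding from the same token reproduces the same tokens and logits. The same pattern in plain usage terms (the model path is illustrative; the counts mirror the test):

    import mlx.core as mx
    from mlx_lm.models.cache import make_prompt_cache, trim_prompt_cache
    from mlx_lm.utils import generate_step, load

    model, tokenizer = load("some/model-path")  # hypothetical path
    prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
    prompt_cache = make_prompt_cache(model)

    # Process the prompt and take one token.
    last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache))
    last_tok = mx.array([last_tok])

    # Speculatively decode two more tokens...
    toks = [t for _, (t, _) in zip(
        range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
    )]

    # ...then rewind the three tokens that entered the cache. Decoding from
    # last_tok again now yields exactly the same two tokens.
    trim_prompt_cache(prompt_cache, 3)
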