Reorg and fixes to caching: unify prompt caching across cache types and use cases, e.g. caching during a chat

Awni Hannun
2024-10-05 14:49:39 -07:00
parent ed060a7c5c
commit 782f5a71b7
40 changed files with 824 additions and 691 deletions


@@ -18,7 +18,7 @@ from mlx.utils import tree_flatten
 from transformers import PreTrainedTokenizer
 
 # Local imports
-from .models.base import KVCache, RotatingKVCache
+from .models import base, cache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
@@ -124,26 +124,6 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
     return logits
 
 
-def make_kv_caches(
-    model: nn.Module, max_kv_size: Optional[int] = None
-) -> List[Union[KVCache, RotatingKVCache]]:
-    if hasattr(model, "make_cache"):
-        return model.make_cache()
-
-    kv_heads = (
-        [model.n_kv_heads] * len(model.layers)
-        if isinstance(model.n_kv_heads, int)
-        else model.n_kv_heads
-    )
-    if max_kv_size is not None:
-        return [
-            RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
-            for n in kv_heads
-        ]
-    else:
-        return [KVCache(model.head_dim, n) for n in kv_heads]
-
-
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
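
The per-layer cache construction moves out of this file and into the models cache module (imported above as `cache`). Based on the call to `cache.make_prompt_cache(model, max_kv_size)` later in this diff, the replacement factory presumably looks roughly like the helper removed here; the sketch below mirrors that removed logic, and the exact constructor arguments of the new KVCache/RotatingKVCache classes are an assumption.

    # A minimal sketch of the factory that replaces make_kv_caches, assuming it
    # keeps the removed helper's per-layer logic. KVCache and RotatingKVCache
    # are assumed to live in the new cache module after this reorg.
    from typing import Any, List, Optional

    import mlx.nn as nn
    from mlx_lm.models.cache import KVCache, RotatingKVCache  # assumed location


    def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
        # Models with unusual layers can provide their own cache objects.
        if hasattr(model, "make_cache"):
            return model.make_cache()

        kv_heads = (
            [model.n_kv_heads] * len(model.layers)
            if isinstance(model.n_kv_heads, int)
            else model.n_kv_heads
        )
        if max_kv_size is not None:
            # Bounded cache: old entries (except the first 4 tokens) are overwritten.
            return [
                RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
                for n in kv_heads
            ]
        return [KVCache(model.head_dim, n) for n in kv_heads]
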
@@ -155,7 +135,7 @@ def generate_step(
     min_tokens_to_keep: int = 1,
     prefill_step_size: int = 512,
     max_kv_size: Optional[int] = None,
-    cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
+    prompt_cache: Optional[Any] = None,
     logit_bias: Optional[Dict[int, float]] = None,
     logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
@@ -180,6 +160,8 @@ def generate_step(
         prefill_step_size (int): Step size for processing the prompt.
         max_kv_size (int, optional): Maximum size of the key-value cache. Old
           entries (except the first 4 tokens) will be overwritten.
+        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
+          provided, the cache will be updated in place.
         logit_bias (dictionary, optional): Additive logit bias.
         logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
           A list of functions that take tokens and logits and return the processed
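
From the caller's side the new argument is simply passed through and mutated in place. A minimal sketch, assuming this file is mlx_lm/utils.py and that make_prompt_cache's max_kv_size defaults to None; the model path is illustrative:

    import mlx.core as mx
    from mlx_lm import load
    from mlx_lm.models import cache
    from mlx_lm.utils import generate_step

    model, tokenizer = load("path/to/model")  # illustrative path or HF repo
    prompt = mx.array(tokenizer.encode("Hello"))

    # One cache object per layer; generate_step updates it in place.
    prompt_cache = cache.make_prompt_cache(model)
    for n, (token, logprobs) in zip(range(8), generate_step(prompt, model, prompt_cache=prompt_cache)):
        print(n, token)
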
@@ -237,20 +219,13 @@ def generate_step(
     tokens = None
 
     # Create the KV cache for generation
-    cache = make_kv_caches(model, max_kv_size)
-
-    if cache_history is not None:
-        if len(cache_history) != len(cache):
-            raise ValueError("Wrong number of layers in the cache history")
-
-        # Set the history in the cache objects and evaluate them to prepare for
-        # generation.
-        for c, h in zip(cache, cache_history):
-            c.update_and_fetch(h[0], h[1])
-        mx.eval([c.state for c in cache])
+    if prompt_cache is None:
+        prompt_cache = cache.make_prompt_cache(model, max_kv_size)
+    elif len(prompt_cache) != len(model.layers):
+        raise ValueError("Wrong number of layers in the prompt cache.")
 
     def _step(y):
-        logits = model(y[None], cache=cache)
+        logits = model(y[None], cache=prompt_cache)
         logits = logits[:, -1, :]
 
         if logits_processor:
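
Because the cache is handed in and updated in place, a caller can pre-compute it for a long shared prefix (for example a system prompt) and later pass only the new tokens as the prompt. A sketch of such a helper, mirroring the prefill loop below; the function name is hypothetical:

    import mlx.core as mx
    from mlx_lm.models import cache


    def precompute_prefix_cache(model, prefix: mx.array, step: int = 512):
        # Run the model over the prefix in chunks, filling the cache as we go.
        prompt_cache = cache.make_prompt_cache(model)
        while prefix.size > 0:
            model(prefix[:step][None], cache=prompt_cache)
            mx.eval([c.state for c in prompt_cache])  # force evaluation per chunk
            prefix = prefix[step:]
        return prompt_cache

The resulting list can then be passed as prompt_cache to generate_step, which only has to process the remaining tokens.
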
@@ -265,7 +240,7 @@
 
     while y.size > prefill_step_size:
         model(y[:prefill_step_size][None], cache=cache)
-        mx.eval([c.state for c in cache])
+        mx.eval([c.state[0] for c in cache])
         y = y[prefill_step_size:]
 
     y, logprobs = _step(y)
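
The chunked prefill exists because MLX is lazy: forcing an mx.eval on the cache state after every prefill_step_size tokens keeps the recorded graph, and with it peak memory, bounded while the prompt is consumed. A toy illustration of that laziness, unrelated to the cache itself:

    import mlx.core as mx

    x = mx.ones((512, 512))
    for _ in range(8):
        x = x @ x  # nothing runs yet; MLX only records the operations
    mx.eval(x)     # the whole chain is evaluated here; evaluating inside the
                   # loop instead bounds the graph size, as the prefill loop does
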
@@ -305,9 +280,9 @@ def stream_generate(
     detokenizer = tokenizer.detokenizer
 
     detokenizer.reset()
-    for (token, _), n in zip(
-        generate_step(prompt_tokens, model, **kwargs),
+    for n, (token, _) in zip(
         range(max_tokens),
+        generate_step(prompt_tokens, model, **kwargs),
     ):
         if token == tokenizer.eos_token_id:
             break
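
Swapping the zip argument order is not just cosmetic: zip consumes its iterables left to right, so with the generator first, exhausting range(max_tokens) still pulls one extra token out of generate_step, i.e. one wasted forward pass that (with the cache now reusable) also lands in the prompt cache. With the range first, iteration stops exactly at the budget:

    def gen():
        i = 0
        while True:
            print("step", i)  # stands in for one forward pass
            yield i
            i += 1

    list(zip(gen(), range(2)))  # prints step 0, step 1, step 2 -- one extra step
    list(zip(range(2), gen()))  # prints step 0, step 1 -- stops at the budget
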
@@ -357,9 +332,9 @@ def generate(
     tic = time.perf_counter()
     detokenizer.reset()
 
-    for (token, logprobs), n in zip(
-        generate_step(prompt_tokens, model, **kwargs),
+    for n, (token, logprobs) in zip(
         range(max_tokens),
+        generate_step(prompt_tokens, model, **kwargs),
     ):
         if n == 0:
             prompt_time = time.perf_counter() - tic
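
Taken together, this is what enables the "caching during a chat" case from the commit message: both generate and stream_generate forward **kwargs to generate_step, so a caller can keep one prompt cache alive across turns. A hedged sketch (the model path and prompts are illustrative, and a real chat would apply the tokenizer's chat template):

    from mlx_lm import generate, load
    from mlx_lm.models import cache

    model, tokenizer = load("path/to/model")  # illustrative
    prompt_cache = cache.make_prompt_cache(model)

    # Turn 1: the prompt and the generated reply both end up in the shared cache.
    print(generate(model, tokenizer, "Hi, who are you?", prompt_cache=prompt_cache))

    # Turn 2: only the new tokens are processed; the cached history is reused.
    print(generate(model, tokenizer, "What can you do?", prompt_cache=prompt_cache))
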