Handle longer prompt/generation (#931)

* rebase

* nits

* nit

* fix rotating cache with step prefill

* update version
Author: Awni Hannun
Date: 2024-08-16 15:28:39 -07:00
Committed by: GitHub
Parent: e196fa3208
Commit: 7be292c0c9

32 changed files with 255 additions and 13 deletions

@@ -19,7 +19,7 @@ from mlx.utils import tree_flatten
 from transformers import PreTrainedTokenizer
 
 # Local imports
-from .models.base import KVCache
+from .models.base import KVCache, RotatingKVCache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import apply_lora_layers
@@ -136,6 +136,8 @@ def generate_step(
     min_p: float = 0.0,
     min_tokens_to_keep: int = 1,
     logit_bias: Optional[Dict[int, float]] = None,
+    prefill_step_size: int = 512,
+    max_kv_size: Optional[int] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -156,6 +158,9 @@ def generate_step(
         min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
           be filtered by min_p sampling.
         logit_bias (dictionary, optional): Additive logit bias.
+        prefill_step_size (int): Step size for processing the prompt.
+        max_kv_size (int, optional): Maximum size of the key-value cache. Old
+          entries (except the first 4 tokens) will be overwritten.
 
     Yields:
         Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
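Taken together, the two new arguments let a caller bound both the prefill chunk size and the KV cache footprint. Below is a minimal, hedged sketch of how they might be passed to `generate_step`; the model path, token budget, and chosen sizes are placeholders and not part of this commit.

```python
# Sketch only: drives the updated generate_step with the two new arguments.
import mlx.core as mx
from mlx_lm.utils import generate_step, load

model, tokenizer = load("path/to/model")  # any mlx_lm-compatible model
prompt = mx.array(tokenizer.encode("Summarize this very long document ..."))

tokens = []
for (token, logprobs), _ in zip(
    generate_step(
        prompt,
        model,
        prefill_step_size=512,  # process the prompt in 512-token chunks
        max_kv_size=1024,       # rotate the KV cache once it holds 1024 entries
    ),
    range(256),                 # cap generation at 256 new tokens
):
    token = token.item()
    if token == tokenizer.eos_token_id:
        break
    tokens.append(token)

print(tokenizer.decode(tokens))
```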
@@ -197,7 +202,13 @@ def generate_step(
         if isinstance(model.n_kv_heads, int)
         else model.n_kv_heads
     )
-    cache = [KVCache(model.head_dim, n) for n in kv_heads]
+    if max_kv_size is not None:
+        cache = [
+            RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
+            for n in kv_heads
+        ]
+    else:
+        cache = [KVCache(model.head_dim, n) for n in kv_heads]
 
     repetition_context = prompt.tolist()
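The branch above only decides which cache type gets constructed; the rotation itself lives in `RotatingKVCache`. As a rough mental model (a simplified standalone illustration, not the actual mlx_lm implementation), `keep=4` pins the first four entries while the remaining `max_size - 4` slots are reused ring-buffer style:

```python
# Simplified illustration of the keep/max_size idea; the real RotatingKVCache
# stores key/value tensors and handles batched updates, trimming, and offsets.
class ToyRotatingCache:
    def __init__(self, max_size: int, keep: int = 4):
        self.max_size = max_size
        self.keep = keep
        self.entries = []   # pinned prefix followed by the ring-buffer region
        self._next = keep   # next ring slot to overwrite once full

    def add(self, item):
        if len(self.entries) < self.max_size:
            self.entries.append(item)        # still filling up
        else:
            self.entries[self._next] = item  # overwrite the oldest non-pinned slot
            self._next += 1
            if self._next == self.max_size:
                self._next = self.keep       # wrap around, skipping the kept prefix

cache = ToyRotatingCache(max_size=8, keep=4)
for tok in range(20):
    cache.add(tok)
print(cache.entries)  # [0, 1, 2, 3, 16, 17, 18, 19] -- first 4 kept, rest rotated
```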
@@ -223,6 +234,11 @@ def generate_step(
                 repetition_context = repetition_context[-repetition_context_size:]
         return y, logprobs.squeeze(0)
 
+    while y.size > prefill_step_size:
+        model(y[:prefill_step_size][None], cache=cache)
+        mx.eval([c.state for c in cache])
+        y = y[prefill_step_size:]
+
     y, logprobs = _step(y)
 
     mx.async_eval(y)
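The new prefill loop feeds the prompt through the model in fixed-size chunks, materializing the cache state after each one, and leaves the final partial chunk (at most `prefill_step_size` tokens) for the first regular `_step` call. A standalone sketch of the same chunking arithmetic on a plain list, with illustrative sizes:

```python
# Standalone sketch of the prefill chunking; the numbers are illustrative only.
prefill_step_size = 512
prompt = list(range(1300))  # stands in for a 1300-token prompt

chunks = []
while len(prompt) > prefill_step_size:
    chunks.append(prompt[:prefill_step_size])  # each chunk would be one model(...) call
    prompt = prompt[prefill_step_size:]

print([len(c) for c in chunks], len(prompt))   # [512, 512] 276
```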
@@ -343,8 +359,10 @@ def generate(
             return
 
         prompt_tps = prompt_tokens.size / prompt_time
         gen_tps = (token_count - 1) / gen_time
-        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
-        print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+        print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
+        print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
+        peak_mem = mx.metal.get_peak_memory() / 2**30
+        print(f"Peak memory: {peak_mem:.3f} GB")
 
     return detokenizer.text
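With verbose output enabled, the summary now reports token counts and peak Metal memory (bytes divided by 2**30) alongside throughput. A hedged usage sketch; the prompt, limits, and the pass-through of `max_kv_size` to `generate_step` via keyword arguments are assumptions here, not something this hunk shows directly:

```python
# Sketch only: exercise the verbose statistics printed by generate().
from mlx_lm.utils import generate, load

model, tokenizer = load("path/to/model")

text = generate(
    model,
    tokenizer,
    prompt="Write a short story about a rotating cache.",
    max_tokens=200,
    verbose=True,     # prints the Prompt/Generation/Peak memory lines shown above
    max_kv_size=512,  # assumed to be forwarded to generate_step as a keyword argument
)
```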