Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 12:49:50 +08:00)
Handle longer prompt/generation (#931)
* rebase
* nits
* nit
* fix rotating cache with step prefill
* update version
@@ -19,7 +19,7 @@ from mlx.utils import tree_flatten
 from transformers import PreTrainedTokenizer
 
 # Local imports
-from .models.base import KVCache
+from .models.base import KVCache, RotatingKVCache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import apply_lora_layers
@@ -136,6 +136,8 @@ def generate_step(
     min_p: float = 0.0,
     min_tokens_to_keep: int = 1,
     logit_bias: Optional[Dict[int, float]] = None,
+    prefill_step_size: int = 512,
+    max_kv_size: Optional[int] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -156,6 +158,9 @@ def generate_step(
         min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
           be filtered by min_p sampling.
         logit_bias (dictionary, optional): Additive logit bias.
+        prefill_step_size (int): Step size for processing the prompt.
+        max_kv_size (int, optional): Maximum size of the key-value cache. Old
+          entries (except the first 4 tokens) will be overwritten.
 
     Yields:
         Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
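For context (not part of the diff): a minimal usage sketch of the two new parameters, assuming mlx_lm's load() helper and calling the generate_step() generator from this file directly; the model name and prompt are placeholders.

import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")  # placeholder model
prompt = mx.array(tokenizer.encode("Write a long story about a boat."))

tokens = []
for (token, logprobs), _ in zip(
    generate_step(prompt, model, prefill_step_size=512, max_kv_size=1024),
    range(512),  # cap the number of generated tokens
):
    if token.item() == tokenizer.eos_token_id:
        break
    tokens.append(token.item())

print(tokenizer.decode(tokens))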
@@ -197,7 +202,13 @@ def generate_step(
         if isinstance(model.n_kv_heads, int)
         else model.n_kv_heads
     )
-    cache = [KVCache(model.head_dim, n) for n in kv_heads]
+    if max_kv_size is not None:
+        cache = [
+            RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
+            for n in kv_heads
+        ]
+    else:
+        cache = [KVCache(model.head_dim, n) for n in kv_heads]
 
     repetition_context = prompt.tolist()
 
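The RotatingKVCache selected here is what bounds memory for long generations. Below is a toy illustration (not part of the diff, and not the actual RotatingKVCache from mlx_lm.models.base) of the behavior the docstring describes: a fixed-size buffer where, once full, new entries overwrite the oldest ones while the first keep positions are preserved.

class ToyRotatingCache:
    def __init__(self, max_size, keep=4):
        self.max_size = max_size
        self.keep = keep
        self.entries = []   # stands in for cached key/value tensors
        self._idx = keep    # next slot to overwrite once the buffer is full

    def add(self, entry):
        if len(self.entries) < self.max_size:
            self.entries.append(entry)
            return
        # Full: overwrite the oldest rotating slot, never the first `keep` entries.
        self.entries[self._idx] = entry
        self._idx += 1
        if self._idx == self.max_size:
            self._idx = self.keep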
@@ -223,6 +234,11 @@ def generate_step(
             repetition_context = repetition_context[-repetition_context_size:]
         return y, logprobs.squeeze(0)
 
+    while y.size > prefill_step_size:
+        model(y[:prefill_step_size][None], cache=cache)
+        mx.eval([c.state for c in cache])
+        y = y[prefill_step_size:]
+
     y, logprobs = _step(y)
 
     mx.async_eval(y)
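As a worked example of the prefill loop added above (not part of the diff): with the default prefill_step_size of 512, a 1300-token prompt is consumed in two 512-token chunks, and the remaining 276 tokens go through the first _step(y) call, which also produces the first generated token.

import mlx.core as mx

prefill_step_size = 512
y = mx.zeros(1300)  # stands in for a 1300-token prompt

chunks = []
while y.size > prefill_step_size:
    chunks.append(y[:prefill_step_size].size)  # the real loop calls model(...) here
    y = y[prefill_step_size:]
chunks.append(y.size)  # the remainder is handled by the first _step(y) call

print(chunks)  # [512, 512, 276]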
@@ -343,8 +359,10 @@ def generate(
             return
         prompt_tps = prompt_tokens.size / prompt_time
         gen_tps = (token_count - 1) / gen_time
-        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
-        print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+        print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
+        print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
+        peak_mem = mx.metal.get_peak_memory() / 2**30
+        print(f"Peak memory: {peak_mem:.3f} GB")
 
     return detokenizer.text
 
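The new verbose report also surfaces peak Metal memory. A small standalone sketch (not part of the diff) of reading the same statistic with the call the diff uses, mx.metal.get_peak_memory():

import mlx.core as mx

a = mx.random.normal((4096, 4096))
mx.eval(a @ a)  # do some work so the allocator has something to report
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")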