Mirror of https://github.com/ml-explore/mlx-examples.git
Generation refactor: part 2 (#1099)
* unify with stream_generate
* fixes
* nit
* some cleanup, warnings, tests
* fix test + faster min p + test
* version
@@ -8,6 +8,7 @@ import json
import logging
import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
@@ -44,6 +45,32 @@ class ModelNotFoundError(Exception):
        super().__init__(self.message)


@dataclass
class GenerationResponse:
    """
    The output of :func:`stream_generate`.

    Args:
        text (str): The next segment of decoded text. This can be an empty string.
        token (int): The next token.
        logprobs (mx.array): A vector of log probabilities.
        prompt_tokens (int): The number of tokens in the prompt.
        prompt_tps (float): The prompt processing tokens-per-second.
        generation_tokens (int): The number of generated tokens.
        generation_tps (float): The tokens-per-second for generation.
        peak_memory (float): The peak memory used so far in GB.
    """

    text: str
    token: int
    logprobs: mx.array
    prompt_tokens: int
    prompt_tps: float
    generation_tokens: int
    generation_tps: float
    peak_memory: float
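For orientation, here is a minimal sketch of how a caller consumes these fields through the new streaming API. The model path is illustrative, and `load` / `stream_generate` are assumed to be the package's top-level exports:

from mlx_lm import load, stream_generate

# Illustrative model repo; any MLX-format model path works the same way.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

for response in stream_generate(model, tokenizer, "Write a haiku.", max_tokens=64):
    print(response.text, end="", flush=True)

# The most recent response carries the cumulative statistics.
print()
print(f"{response.generation_tokens} tokens at {response.generation_tps:.2f} tok/s")
print(f"Peak memory: {response.peak_memory:.3f} GB")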


@contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
    """
@@ -155,20 +182,21 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
def generate_step(
    prompt: mx.array,
    model: nn.Module,
    temp: float = 0.0,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = 20,
    top_p: float = 1.0,
    min_p: float = 0.0,
    min_tokens_to_keep: int = 1,
    prefill_step_size: int = 512,
    *,
    sampler: Optional[Callable[mx.array, mx.array]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[Any] = None,
    logit_bias: Optional[Dict[int, float]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    prefill_step_size: int = 512,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    temp: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    top_p: Optional[float] = None,
    min_p: Optional[float] = None,
    min_tokens_to_keep: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.
@@ -176,32 +204,21 @@ def generate_step(
    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        temp (float): The temperature for sampling, if 0 the argmax is used.
            Default: ``0``.
        repetition_penalty (float, optional): The penalty factor for repeating
            tokens.
        repetition_context_size (int, optional): The number of tokens to
            consider for repetition penalty. Default: ``20``.
        top_p (float, optional): Nucleus sampling, higher means model considers
            more less likely words.
        min_p (float, optional): The minimum value (scaled by the top token's
            probability) that a token probability must have to be considered.
        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
            be filtered by min_p sampling.
        prefill_step_size (int): Step size for processing the prompt.
        max_kv_size (int, optional): Maximum size of the key-value cache. Old
            entries (except the first 4 tokens) will be overwritten.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
            provided, the cache will be updated in place.
        logit_bias (dictionary, optional): Additive logit bias.
        sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
            token from a vector of log probabilities. Default: ``None``.
        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
            A list of functions that take tokens and logits and return the processed
            logits. Default: ``None``.
        kv_bits (int, optional): Number of bits to use for KV cache quantization.
            None implies no cache quantization. Default: ``None``.
        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
        quantized_kv_start (int): Step to begin using a quantized KV cache
            when ``kv_bits`` is non-None. Default: ``0``.

    Yields:
        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
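A sketch of calling the refactored generator directly with the new keyword-only arguments. Here `model` and `tokenizer` are assumed to come from `mlx_lm.load`, and the keyword names for `make_sampler` mirror the positional call that appears later in the diff:

import mlx.core as mx
from mlx_lm.sample_utils import make_sampler

prompt = mx.array(tokenizer.encode("The capital of France is"))
sampler = make_sampler(temp=0.7, top_p=0.9)

for n, (token, logprobs) in enumerate(
    generate_step(prompt, model, sampler=sampler, max_kv_size=1024)
):
    if token == tokenizer.eos_token_id or n >= 32:
        break
    # logprobs is a vector over the vocabulary; the chosen token's log probability:
    token_logprob = logprobs[token]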
@@ -219,10 +236,22 @@ def generate_step(
    elif len(prompt_cache) != len(model.layers):
        raise ValueError("Wrong number of layers in the prompt cache.")

    sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
    logits_processors = logits_processors or []
    logits_processors.extend(
        make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
    if temp is not None or top_p is not None or min_tokens_to_keep is not None:
        print(
            "[Warning] Specifying sampling arguments to ``generate_step`` is "
            "deprecated. Pass in a ``sampler`` instead."
        )
    if repetition_penalty is not None:
        print(
            "[Warning] Specifying ``repetition_penalty`` is deprecated. "
            "Pass in ``logits_processors`` instead."
        )

    sampler = sampler or make_sampler(
        temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
    )
    logits_processors = logits_processors or make_logits_processors(
        None, repetition_penalty, repetition_context_size or 20
    )

    def _step(y):
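As a rough equivalent of the deprecated keyword path above, a caller now builds the sampler and processors with the helpers from `mlx_lm.sample_utils` (positional arguments as used in the diff); `prompt` and `model` are assumed from the previous sketch:

from mlx_lm.sample_utils import make_logits_processors, make_sampler

# Previously: generate_step(prompt, model, temp=0.8, repetition_penalty=1.1, logit_bias={...})
sampler = make_sampler(0.8, 0.95, 0.0, 1)  # temp, top_p, min_p, min_tokens_to_keep
logits_processors = make_logits_processors(
    None,  # logit_bias, e.g. {token_id: bias}
    1.1,   # repetition_penalty
    20,    # repetition_context_size
)

token_generator = generate_step(
    prompt, model, sampler=sampler, logits_processors=logits_processors
)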
@@ -290,17 +319,20 @@ def stream_generate(
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    prompt_tokens = mx.array(
        prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
    )
    prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt))
    detokenizer = tokenizer.detokenizer

    with wired_limit(model, [generation_stream]):
        detokenizer.reset()
        for n, (token, logits) in zip(
        tic = time.perf_counter()
        for n, (token, logprobs) in zip(
            range(max_tokens),
            generate_step(prompt_tokens, model, **kwargs),
            generate_step(prompt, model, **kwargs),
        ):
            if n == 0:
                prompt_time = time.perf_counter() - tic
                prompt_tps = prompt.size / prompt_time
                tic = time.perf_counter()
            if token == tokenizer.eos_token_id:
                break
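Because the prompt is wrapped with `prompt if isinstance(prompt, list) else tokenizer.encode(prompt)`, `stream_generate` accepts either a string or a pre-tokenized list of ids. A sketch with a chat-templated prompt, assuming the underlying Hugging Face tokenizer provides `apply_chat_template`:

messages = [{"role": "user", "content": "Summarize KV cache quantization in one sentence."}]
# With the default tokenize=True this returns a list of token ids.
prompt_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

for response in stream_generate(model, tokenizer, prompt_ids, max_tokens=128):
    print(response.text, end="", flush=True)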
@@ -309,17 +341,34 @@ def stream_generate(
            if n == (max_tokens - 1):
                break

            yield detokenizer.last_segment, token, logits
            yield GenerationResponse(
                text=detokenizer.last_segment,
                token=token,
                logprobs=logprobs,
                prompt_tokens=prompt.size,
                prompt_tps=prompt_tps,
                generation_tokens=n + 1,
                generation_tps=(n + 1) / (time.perf_counter() - tic),
                peak_memory=mx.metal.get_peak_memory() / 1e9,
            )

        detokenizer.finalize()
        yield detokenizer.last_segment, token, logits
        yield GenerationResponse(
            text=detokenizer.last_segment,
            token=token,
            logprobs=logprobs,
            prompt_tokens=prompt.size,
            prompt_tps=prompt_tps,
            generation_tokens=n + 1,
            generation_tps=(n + 1) / (time.perf_counter() - tic),
            peak_memory=mx.metal.get_peak_memory() / 1e9,
        )

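Each yielded response carries the full log-probability vector, so a per-token probability can be recovered the same way the previous `generate` body did with `mx.exp(logprobs[token])`. A small sketch, reusing the model and tokenizer from earlier:

import mlx.core as mx

for response in stream_generate(model, tokenizer, "2 + 2 =", max_tokens=8):
    # Probability assigned to the sampled token for this step.
    prob = mx.exp(response.logprobs[response.token]).item()
    print(f"{response.text!r} (p={prob:.3f})")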
def generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    **kwargs,
@@ -334,64 +383,40 @@ def generate(
        max_tokens (int): The maximum number of tokens. Default: ``100``.
        verbose (bool): If ``True``, print tokens and timing information.
            Default: ``False``.
        formatter (Optional[Callable]): A function which takes a token and a
            probability and displays it.
        kwargs: The remaining options get passed to :func:`generate_step`.
            See :func:`generate_step` for more details.
        kwargs: The remaining options get passed to :func:`stream_generate`.
            See :func:`stream_generate` for more details.
    """
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    if formatter is not None:
        print(
            "[Warning] Text formatting is deprecated and no longer used. "
            "The argument will be removed in a future version."
        )
    if verbose:
        print("=" * 10)
        print("Prompt:", prompt)

    prompt_tokens = mx.array(tokenizer.encode(prompt))
    detokenizer = tokenizer.detokenizer

    with wired_limit(model, [generation_stream]):
        tic = time.perf_counter()
        detokenizer.reset()
        for n, (token, logprobs) in zip(
            range(max_tokens),
            generate_step(prompt_tokens, model, **kwargs),
        ):
            if n == 0:
                prompt_time = time.perf_counter() - tic
                tic = time.perf_counter()
            if token == tokenizer.eos_token_id:
                break
            detokenizer.add_token(token)

            if verbose:
                if formatter:
                    # We have to finalize so that the prob corresponds to the last segment
                    detokenizer.finalize()
                    prob = mx.exp(logprobs[token]).item()
                    formatter(detokenizer.last_segment, prob)
                else:
                    print(detokenizer.last_segment, end="", flush=True)

        token_count = n + 1
        detokenizer.finalize()

    text = ""
    for response in stream_generate(model, tokenizer, prompt, **kwargs):
        if verbose:
            gen_time = time.perf_counter() - tic
            print(detokenizer.last_segment, flush=True)
            print("=" * 10)
            if token_count == 0:
                print("No tokens generated for this prompt")
                return
            prompt_tps = prompt_tokens.size / prompt_time
            gen_tps = (token_count - 1) / gen_time
            print(
                f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
            )
            print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
            peak_mem = mx.metal.get_peak_memory() / 1e9
            print(f"Peak memory: {peak_mem:.3f} GB")
            print(response.text, end="", flush=True)
        text += response.text

    return detokenizer.text
    if verbose:
        print()
        print("=" * 10)
        if len(text) == 0:
            print("No text generated for this prompt")
            return
        print(
            f"Prompt: {response.prompt_tokens} tokens, "
            f"{response.prompt_tps:.3f} tokens-per-sec"
        )
        print(
            f"Generation: {response.generation_tokens} tokens, "
            f"{response.generation_tps:.3f} tokens-per-sec"
        )
        print(f"Peak memory: {response.peak_memory:.3f} GB")
    return text

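Taken together, the high-level entry point keeps its old shape while extra keyword arguments now flow through `stream_generate` into `generate_step`. A usage sketch (model path illustrative):

from mlx_lm import generate, load
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

text = generate(
    model,
    tokenizer,
    prompt="Explain the difference between top-p and min-p sampling.",
    max_tokens=200,
    verbose=True,  # prints the streamed text plus prompt/generation TPS and peak memory
    sampler=make_sampler(0.6, 0.9),  # forwarded via **kwargs to stream_generate/generate_step
)
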
def load_config(model_path: Path) -> dict: