Generation refactor: part 2 (#1099)

* unify with stream_generate

* fixes

* nit

* some cleanup, warnings, tests

* fix test + faster min p + test

* version
This commit is contained in:
Awni Hannun
2024-11-23 11:47:06 -08:00
committed by GitHub
parent 004eb4cc9d
commit 0f135396ae
13 changed files with 184 additions and 197 deletions

View File

@@ -8,6 +8,7 @@ import json
import logging
import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
@@ -44,6 +45,32 @@ class ModelNotFoundError(Exception):
super().__init__(self.message)
@dataclass
class GenerationResponse:
"""
The output of :func:`stream_generate`.
Args:
text (str): The next segment of decoded text. This can be an empty string.
token (int): The next token.
logprobs (mx.array): A vector of log probabilities.
prompt_tokens (int): The number of tokens in the prompt.
prompt_tps (float): The prompt processing tokens-per-second.
generation_tokens (int): The number of generated tokens.
generation_tps (float): The tokens-per-second for generation.
peak_memory (float): The peak memory used so far in GB.
"""
text: str
token: int
logprobs: mx.array
prompt_tokens: int
prompt_tps: float
generation_tokens: int
generation_tps: float
peak_memory: float
@contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
"""
@@ -155,20 +182,21 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
def generate_step(
prompt: mx.array,
model: nn.Module,
temp: float = 0.0,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
top_p: float = 1.0,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
prefill_step_size: int = 512,
*,
sampler: Optional[Callable[mx.array, mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
prefill_step_size: int = 512,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
temp: Optional[float] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
top_p: Optional[float] = None,
min_p: Optional[float] = None,
min_tokens_to_keep: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -176,32 +204,21 @@ def generate_step(
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
top_p (float, optional): Nulceus sampling, higher means model considers
more less likely words.
min_p (float, optional): The minimum value (scaled by the top token's
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
logit_bias (dictionary, optional): Additive logit bias.
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache.
when ``kv_bits`` is non-None. Default: ``0``.
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
@@ -219,10 +236,22 @@ def generate_step(
elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.")
sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
logits_processors = logits_processors or []
logits_processors.extend(
make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
if temp is not None or top_p is not None or min_tokens_to_keep is not None:
print(
"[Warning] Specifying sampling arguments to ``generate_step`` is "
"deprecated. Pass in a ``sampler`` instead."
)
if repetition_penalty is not None:
print(
"[Warning] Specifying ``repetition_penalty`` is deprecated. "
"Pass in ``logits_processors`` instead."
)
sampler = sampler or make_sampler(
temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
)
logits_processors = logits_processors or make_logits_processors(
None, repetition_penalty, repetition_context_size or 20
)
def _step(y):
@@ -290,17 +319,20 @@ def stream_generate(
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
prompt_tokens = mx.array(
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
)
prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer
with wired_limit(model, [generation_stream]):
detokenizer.reset()
for n, (token, logits) in zip(
tic = time.perf_counter()
for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
generate_step(prompt, model, **kwargs),
):
if n == 0:
prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time
tic = time.perf_counter()
if token == tokenizer.eos_token_id:
break
@@ -309,17 +341,34 @@ def stream_generate(
if n == (max_tokens - 1):
break
yield detokenizer.last_segment, token, logits
yield GenerationResponse(
text=detokenizer.last_segment,
token=token,
logprobs=logprobs,
prompt_tokens=prompt.size,
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.metal.get_peak_memory() / 1e9,
)
detokenizer.finalize()
yield detokenizer.last_segment, token, logits
yield GenerationResponse(
text=detokenizer.last_segment,
token=token,
logprobs=logprobs,
prompt_tokens=prompt.size,
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.metal.get_peak_memory() / 1e9,
)
def generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
max_tokens: int = 100,
verbose: bool = False,
formatter: Optional[Callable] = None,
**kwargs,
@@ -334,64 +383,40 @@ def generate(
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
formatter (Optional[Callable]): A function which takes a token and a
probability and displays it.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
kwargs: The remaining options get passed to :func:`stream_generate`.
See :func:`stream_generate` for more details.
"""
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
if formatter is not None:
print(
"[Warning] Text formatting is deprecated and no longer used. "
"The argument will be removed in a future version."
)
if verbose:
print("=" * 10)
print("Prompt:", prompt)
prompt_tokens = mx.array(tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer
with wired_limit(model, [generation_stream]):
tic = time.perf_counter()
detokenizer.reset()
for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
if token == tokenizer.eos_token_id:
break
detokenizer.add_token(token)
if verbose:
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob)
else:
print(detokenizer.last_segment, end="", flush=True)
token_count = n + 1
detokenizer.finalize()
text = ""
for response in stream_generate(model, tokenizer, prompt, **kwargs):
if verbose:
gen_time = time.perf_counter() - tic
print(detokenizer.last_segment, flush=True)
print("=" * 10)
if token_count == 0:
print("No tokens generated for this prompt")
return
prompt_tps = prompt_tokens.size / prompt_time
gen_tps = (token_count - 1) / gen_time
print(
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
)
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 1e9
print(f"Peak memory: {peak_mem:.3f} GB")
print(response.text, end="", flush=True)
text += response.text
return detokenizer.text
if verbose:
print()
print("=" * 10)
if len(text) == 0:
print("No text generated for this prompt")
return
print(
f"Prompt: {response.prompt_tokens} tokens, "
f"{response.prompt_tps:.3f} tokens-per-sec"
)
print(
f"Generation: {response.generation_tokens} tokens, "
f"{response.generation_tps:.3f} tokens-per-sec"
)
print(f"Peak memory: {response.peak_memory:.3f} GB")
return text
def load_config(model_path: Path) -> dict: