unify with stream_generate

This commit is contained in:
Awni Hannun 2024-11-07 17:21:15 -08:00
parent 004eb4cc9d
commit 431988721f
3 changed files with 98 additions and 122 deletions

View File

@ -97,11 +97,6 @@ def setup_arg_parser():
default=True, default=True,
help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'", help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
) )
parser.add_argument(
"--colorize",
action="store_true",
help="Colorize output based on T[0] probability",
)
parser.add_argument( parser.add_argument(
"--max-kv-size", "--max-kv-size",
type=int, type=int,
@ -137,33 +132,6 @@ def setup_arg_parser():
return parser return parser
def colorprint(color, s):
    """Write ``s`` to stdout in bold ``color`` using ANSI SGR escapes.

    No trailing newline is emitted and the stream is flushed immediately,
    so successive calls render as one continuous colored line.
    """
    # Color name -> ANSI SGR foreground code. Note "white" maps to 39,
    # the terminal's *default* foreground, not bright white (97).
    ansi_codes = {
        "black": 30,
        "red": 31,
        "green": 32,
        "yellow": 33,
        "blue": 34,
        "magenta": 35,
        "cyan": 36,
        "white": 39,
    }
    # Unknown color names fall back to black (30).
    code = ansi_codes.get(color, 30)
    # \033[1m = bold, \033[<code>m = set foreground, \033[0m = reset.
    print("\033[1m\033[{}m{}\033[0m".format(code, s), end="", flush=True)
def colorprint_by_t0(s, t0):
    """Print ``s`` colorized by the probability ``t0`` of the sampled token.

    High-confidence tokens render white, shading through green and yellow
    down to red for low-confidence tokens.
    """
    # Probability buckets, checked from highest threshold down; the first
    # threshold exceeded wins.
    buckets = ((0.95, "white"), (0.70, "green"), (0.30, "yellow"))
    for threshold, color in buckets:
        if t0 > threshold:
            break
    else:
        # t0 <= 0.30: lowest-confidence bucket.
        color = "red"
    colorprint(color, s)
def main(): def main():
parser = setup_arg_parser() parser = setup_arg_parser()
args = parser.parse_args() args = parser.parse_args()
@ -250,17 +218,12 @@ def main():
else: else:
prompt = args.prompt prompt = args.prompt
if args.colorize and not args.verbose:
raise ValueError("Cannot use --colorize with --verbose=False")
formatter = colorprint_by_t0 if args.colorize else None
response = generate( response = generate(
model, model,
tokenizer, tokenizer,
prompt, prompt,
args.max_tokens, max_tokens=args.max_tokens,
verbose=args.verbose, verbose=args.verbose,
formatter=formatter,
temp=args.temp, temp=args.temp,
top_p=args.top_p, top_p=args.top_p,
min_p=args.min_p, min_p=args.min_p,

View File

@ -464,25 +464,21 @@ class APIHandler(BaseHTTPRequestHandler):
text = "" text = ""
tic = time.perf_counter() tic = time.perf_counter()
for n, (segment, token, logprobs) in enumerate( for gen_response in stream_generate(
stream_generate( model=self.model,
model=self.model, tokenizer=self.tokenizer,
tokenizer=self.tokenizer, prompt=prompt,
prompt=prompt, max_tokens=self.max_tokens,
max_tokens=self.max_tokens, temp=self.temperature,
temp=self.temperature, repetition_penalty=self.repetition_penalty,
repetition_penalty=self.repetition_penalty, repetition_context_size=self.repetition_context_size,
repetition_context_size=self.repetition_context_size, logit_bias=self.logit_bias,
logit_bias=self.logit_bias, prompt_cache=self.prompt_cache.cache,
prompt_cache=self.prompt_cache.cache,
),
): ):
if n == 0: text += gen_response.text
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
text += segment
logging.debug(text) logging.debug(text)
token = gen_response.token
logprobs = gen_response.logprobs
tokens.append(token) tokens.append(token)
if self.logprobs > 0: if self.logprobs > 0:
@ -523,13 +519,9 @@ class APIHandler(BaseHTTPRequestHandler):
self.prompt_cache.tokens.extend(tokens) self.prompt_cache.tokens.extend(tokens)
gen_time = time.perf_counter() - tic logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec")
prompt_tps = len(prompt) / prompt_time logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec")
gen_tps = len(tokens) / gen_time logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB")
peak_mem = mx.metal.get_peak_memory() / 1e9
logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
logging.debug(f"Peak memory: {peak_mem:.3f} GB")
if self.stream: if self.stream:
response = self.generate_response(segment, finish_reason) response = self.generate_response(segment, finish_reason)

View File

@ -8,6 +8,7 @@ import json
import logging import logging
import shutil import shutil
import time import time
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from textwrap import dedent from textwrap import dedent
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
@ -44,6 +45,32 @@ class ModelNotFoundError(Exception):
super().__init__(self.message) super().__init__(self.message)
@dataclass
class GenerationResponse:
    """
    The output of :func:`stream_generate`.

    One instance is yielded per generated token, carrying the newly decoded
    text segment plus running throughput statistics.

    Args:
        text (str): The next segment of decoded text. This can be an empty string.
        token (int): The next token.
        logprobs (mx.array): A vector of log probabilities.
        prompt_tokens (int): The number of tokens in the prompt.
        prompt_tps (float): The prompt processing tokens-per-second.
        generation_tokens (int): The number of generated tokens.
        generation_tps (float): The tokens-per-second for generation.
        peak_memory (float): The peak memory used so far in GB.
    """

    text: str
    token: int
    logprobs: mx.array
    prompt_tokens: int
    prompt_tps: float
    generation_tokens: int
    generation_tps: float
    peak_memory: float
@contextlib.contextmanager @contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None): def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
""" """
@ -290,17 +317,20 @@ def stream_generate(
if not isinstance(tokenizer, TokenizerWrapper): if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer) tokenizer = TokenizerWrapper(tokenizer)
prompt_tokens = mx.array( prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt))
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
)
detokenizer = tokenizer.detokenizer detokenizer = tokenizer.detokenizer
with wired_limit(model, [generation_stream]): with wired_limit(model, [generation_stream]):
detokenizer.reset() detokenizer.reset()
for n, (token, logits) in zip( tic = time.perf_counter()
for n, (token, logprobs) in zip(
range(max_tokens), range(max_tokens),
generate_step(prompt_tokens, model, **kwargs), generate_step(prompt, model, **kwargs),
): ):
if n == 0:
prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time
tic = time.perf_counter()
if token == tokenizer.eos_token_id: if token == tokenizer.eos_token_id:
break break
@ -309,17 +339,34 @@ def stream_generate(
if n == (max_tokens - 1): if n == (max_tokens - 1):
break break
yield detokenizer.last_segment, token, logits yield GenerationResponse(
text=detokenizer.last_segment,
token=token,
logprobs=logprobs,
prompt_tokens=prompt.size,
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.metal.get_peak_memory() / 1e9,
)
detokenizer.finalize() detokenizer.finalize()
yield detokenizer.last_segment, token, logits yield GenerationResponse(
text=detokenizer.last_segment,
token=token,
logprobs=logprobs,
prompt_tokens=prompt.size,
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.metal.get_peak_memory() / 1e9,
)
def generate( def generate(
model: nn.Module, model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str, prompt: str,
max_tokens: int = 100,
verbose: bool = False, verbose: bool = False,
formatter: Optional[Callable] = None, formatter: Optional[Callable] = None,
**kwargs, **kwargs,
@ -331,67 +378,41 @@ def generate(
model (nn.Module): The language model. model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer. tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt. prompt (str): The string prompt.
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information. verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``. Default: ``False``.
formatter (Optional[Callable]): A function which takes a token and a kwargs: The remaining options get passed to :func:`stream_generate`.
probability and displays it. See :func:`stream_generate` for more details.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
""" """
if not isinstance(tokenizer, TokenizerWrapper): if formatter is not None:
tokenizer = TokenizerWrapper(tokenizer) print(
"Text formatting has been deprecated and will be removed in the next version."
)
if verbose: if verbose:
print("=" * 10) print("=" * 10)
print("Prompt:", prompt) print("Prompt:", prompt)
prompt_tokens = mx.array(tokenizer.encode(prompt)) full_text = ""
detokenizer = tokenizer.detokenizer for response in stream_generate(model, tokenizer, prompt, **kwargs):
with wired_limit(model, [generation_stream]):
tic = time.perf_counter()
detokenizer.reset()
for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
if token == tokenizer.eos_token_id:
break
detokenizer.add_token(token)
if verbose:
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob)
else:
print(detokenizer.last_segment, end="", flush=True)
token_count = n + 1
detokenizer.finalize()
if verbose: if verbose:
gen_time = time.perf_counter() - tic print(response.text, end="", flush=True)
print(detokenizer.last_segment, flush=True) full_text += response.text
print("=" * 10)
if token_count == 0:
print("No tokens generated for this prompt")
return
prompt_tps = prompt_tokens.size / prompt_time
gen_tps = (token_count - 1) / gen_time
print(
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
)
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 1e9
print(f"Peak memory: {peak_mem:.3f} GB")
return detokenizer.text if verbose:
print()
print("=" * 10)
if len(full_text) == 0:
print("No text generated for this prompt")
return
print(
f"Prompt: {response.prompt_tokens} tokens, "
f"{response.prompt_tps:.3f} tokens-per-sec"
)
print(
f"Generation: {response.generation_tokens} tokens, "
f"{response.generation_tps:.3f} tokens-per-sec"
)
print(f"Peak memory: {response.peak_memory:.3f} GB")
return full_text
def load_config(model_path: Path) -> dict: def load_config(model_path: Path) -> dict: