From 431988721f2cad897e76074d6e8df412f16faea0 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 7 Nov 2024 17:21:15 -0800 Subject: [PATCH] unify with stream_generate --- llms/mlx_lm/generate.py | 39 +---------- llms/mlx_lm/server.py | 40 +++++------- llms/mlx_lm/utils.py | 141 +++++++++++++++++++++++----------------- 3 files changed, 98 insertions(+), 122 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index de5c5719..09849632 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -97,11 +97,6 @@ def setup_arg_parser(): default=True, help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'", ) - parser.add_argument( - "--colorize", - action="store_true", - help="Colorize output based on T[0] probability", - ) parser.add_argument( "--max-kv-size", type=int, @@ -137,33 +132,6 @@ def setup_arg_parser(): return parser -def colorprint(color, s): - color_codes = { - "black": 30, - "red": 31, - "green": 32, - "yellow": 33, - "blue": 34, - "magenta": 35, - "cyan": 36, - "white": 39, - } - ccode = color_codes.get(color, 30) - print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True) - - -def colorprint_by_t0(s, t0): - if t0 > 0.95: - color = "white" - elif t0 > 0.70: - color = "green" - elif t0 > 0.30: - color = "yellow" - else: - color = "red" - colorprint(color, s) - - def main(): parser = setup_arg_parser() args = parser.parse_args() @@ -250,17 +218,12 @@ def main(): else: prompt = args.prompt - if args.colorize and not args.verbose: - raise ValueError("Cannot use --colorize with --verbose=False") - formatter = colorprint_by_t0 if args.colorize else None - response = generate( model, tokenizer, prompt, - args.max_tokens, + max_tokens=args.max_tokens, verbose=args.verbose, - formatter=formatter, temp=args.temp, top_p=args.top_p, min_p=args.min_p, diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index c1365b36..a71c305e 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -464,25 +464,21 @@ class APIHandler(BaseHTTPRequestHandler): text = "" tic = time.perf_counter() - for n, (segment, token, logprobs) in enumerate( - stream_generate( - model=self.model, - tokenizer=self.tokenizer, - prompt=prompt, - max_tokens=self.max_tokens, - temp=self.temperature, - repetition_penalty=self.repetition_penalty, - repetition_context_size=self.repetition_context_size, - logit_bias=self.logit_bias, - prompt_cache=self.prompt_cache.cache, - ), + for gen_response in stream_generate( + model=self.model, + tokenizer=self.tokenizer, + prompt=prompt, + max_tokens=self.max_tokens, + temp=self.temperature, + repetition_penalty=self.repetition_penalty, + repetition_context_size=self.repetition_context_size, + logit_bias=self.logit_bias, + prompt_cache=self.prompt_cache.cache, ): - if n == 0: - prompt_time = time.perf_counter() - tic - tic = time.perf_counter() - - text += segment + text += gen_response.text logging.debug(text) + token = gen_response.token + logprobs = gen_response.logprobs tokens.append(token) if self.logprobs > 0: @@ -523,13 +519,9 @@ class APIHandler(BaseHTTPRequestHandler): self.prompt_cache.tokens.extend(tokens) - gen_time = time.perf_counter() - tic - prompt_tps = len(prompt) / prompt_time - gen_tps = len(tokens) / gen_time - peak_mem = mx.metal.get_peak_memory() / 1e9 - logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec") - logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec") - logging.debug(f"Peak memory: {peak_mem:.3f} GB") + logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec") + logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec") + logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB") if self.stream: response = self.generate_response(segment, finish_reason) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index d4afd428..33f32cc8 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -8,6 +8,7 @@ import json import logging import shutil import time +from dataclasses import dataclass from pathlib import Path from textwrap import dedent from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union @@ -44,6 +45,32 @@ class ModelNotFoundError(Exception): super().__init__(self.message) +@dataclass +class GenerationResponse: + """ + The output of :func:`stream_generate`. + + Args: + text (str): The next segment of decoded text. This can be an empty string. + token (int): The next token. + logprobs (mx.array): A vector of log probabilities. + prompt_tokens (int): The number of tokens in the prompt. + prompt_tps (float): The prompt processing tokens-per-second. + generation_tokens (int): The number of generated tokens. + generation_tps (float): The tokens-per-second for generation. + peak_memory (float): The peak memory used so far in GB. + """ + + text: str + token: int + logprobs: mx.array + prompt_tokens: int + prompt_tps: float + generation_tokens: int + generation_tps: float + peak_memory: float + + @contextlib.contextmanager def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None): """ @@ -290,17 +317,20 @@ def stream_generate( if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) - prompt_tokens = mx.array( - prompt if isinstance(prompt, list) else tokenizer.encode(prompt) - ) + prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt)) detokenizer = tokenizer.detokenizer with wired_limit(model, [generation_stream]): detokenizer.reset() - for n, (token, logits) in zip( + tic = time.perf_counter() + for n, (token, logprobs) in zip( range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), + generate_step(prompt, model, **kwargs), ): + if n == 0: + prompt_time = time.perf_counter() - tic + prompt_tps = prompt.size / prompt_time + tic = time.perf_counter() if token == tokenizer.eos_token_id: break @@ -309,17 +339,34 @@ def stream_generate( if n == (max_tokens - 1): break - yield detokenizer.last_segment, token, logits + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + prompt_tokens=prompt.size, + prompt_tps=prompt_tps, + generation_tokens=n + 1, + generation_tps=(n + 1) / (time.perf_counter() - tic), + peak_memory=mx.metal.get_peak_memory() / 1e9, + ) detokenizer.finalize() - yield detokenizer.last_segment, token, logits + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + prompt_tokens=prompt.size, + prompt_tps=prompt_tps, + generation_tokens=n + 1, + generation_tps=(n + 1) / (time.perf_counter() - tic), + peak_memory=mx.metal.get_peak_memory() / 1e9, + ) def generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], prompt: str, - max_tokens: int = 100, verbose: bool = False, formatter: Optional[Callable] = None, **kwargs, @@ -331,67 +378,41 @@ def generate( model (nn.Module): The language model. tokenizer (PreTrainedTokenizer): The tokenizer. prompt (str): The string prompt. - max_tokens (int): The maximum number of tokens. Default: ``100``. verbose (bool): If ``True``, print tokens and timing information. Default: ``False``. - formatter (Optional[Callable]): A function which takes a token and a - probability and displays it. - kwargs: The remaining options get passed to :func:`generate_step`. - See :func:`generate_step` for more details. + kwargs: The remaining options get passed to :func:`stream_generate`. + See :func:`stream_generate` for more details. """ - if not isinstance(tokenizer, TokenizerWrapper): - tokenizer = TokenizerWrapper(tokenizer) - + if formatter is not None: + print( + "Text formatting has been deprecated and will be removed in the next version." + ) if verbose: print("=" * 10) print("Prompt:", prompt) - prompt_tokens = mx.array(tokenizer.encode(prompt)) - detokenizer = tokenizer.detokenizer - - with wired_limit(model, [generation_stream]): - tic = time.perf_counter() - detokenizer.reset() - for n, (token, logprobs) in zip( - range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), - ): - if n == 0: - prompt_time = time.perf_counter() - tic - tic = time.perf_counter() - if token == tokenizer.eos_token_id: - break - detokenizer.add_token(token) - - if verbose: - if formatter: - # We have to finalize so that the prob corresponds to the last segment - detokenizer.finalize() - prob = mx.exp(logprobs[token]).item() - formatter(detokenizer.last_segment, prob) - else: - print(detokenizer.last_segment, end="", flush=True) - - token_count = n + 1 - detokenizer.finalize() - + full_text = "" + for response in stream_generate(model, tokenizer, prompt, **kwargs): if verbose: - gen_time = time.perf_counter() - tic - print(detokenizer.last_segment, flush=True) - print("=" * 10) - if token_count == 0: - print("No tokens generated for this prompt") - return - prompt_tps = prompt_tokens.size / prompt_time - gen_tps = (token_count - 1) / gen_time - print( - f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec" - ) - print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") - peak_mem = mx.metal.get_peak_memory() / 1e9 - print(f"Peak memory: {peak_mem:.3f} GB") + print(response.text, end="", flush=True) + full_text += response.text - return detokenizer.text + if verbose: + print() + print("=" * 10) + if len(full_text) == 0: + print("No text generated for this prompt") + return + print( + f"Prompt: {response.prompt_tokens} tokens, " + f"{response.prompt_tps:.3f} tokens-per-sec" + ) + print( + f"Generation: {response.generation_tokens} tokens, " + f"{response.generation_tps:.3f} tokens-per-sec" + ) + print(f"Peak memory: {response.peak_memory:.3f} GB") + return full_text def load_config(model_path: Path) -> dict: