From 40b61c1719d83bf5199929da648f4c7e43bbb4a7 Mon Sep 17 00:00:00 2001
From: iLoveBug <40977151+iLoveBug@users.noreply.github.com>
Date: Wed, 24 Jan 2024 04:44:23 +0800
Subject: [PATCH] fix the Chinese character generation the same as PR #321 (#342)

* fix the Chinese character generation the same as PR #321

* reuse the generate logic in utils.py

* format

* verbose default

* fix conflicts with colorize and character check

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/generate.py | 71 ++++++++++++-----------------------------
 llms/mlx_lm/utils.py    | 58 +++++++++++++++++++++++----------
 2 files changed, 63 insertions(+), 66 deletions(-)

diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index a6e89afa..57f080e2 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -1,9 +1,8 @@
 import argparse
-import time
 
 import mlx.core as mx
 
-from .utils import generate_step, load
+from .utils import generate, load
 
 DEFAULT_MODEL_PATH = "mlx_model"
 DEFAULT_PROMPT = "hello"
@@ -53,7 +52,7 @@ def setup_arg_parser():
     )
     parser.add_argument(
         "--colorize",
-        action='store_true',
+        action="store_true",
         help="Colorize output based on T[0] probability",
     )
     return parser
@@ -61,29 +60,29 @@ def setup_arg_parser():
 
 def colorprint(color, s):
     color_codes = {
-        'black': 30,
-        'red': 31,
-        'green': 32,
-        'yellow': 33,
-        'blue': 34,
-        'magenta': 35,
-        'cyan': 36,
-        'white': 39,
+        "black": 30,
+        "red": 31,
+        "green": 32,
+        "yellow": 33,
+        "blue": 34,
+        "magenta": 35,
+        "cyan": 36,
+        "white": 39,
     }
     ccode = color_codes.get(color, 30)
     print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
 
 
-def colorprint_by_t0(t0, s):
+def colorprint_by_t0(s, t0):
     if t0 > 0.95:
-        color = 'white'
+        color = "white"
     elif t0 > 0.70:
-        color = 'green'
+        color = "green"
     elif t0 > 0.30:
-        color = 'yellow'
+        color = "yellow"
     else:
-        color = 'red'
-    colorprint(color,s)
+        color = "red"
+    colorprint(color, s)
 
 
 def main(args):
@@ -107,39 +106,11 @@ def main(args):
     else:
         prompt = args.prompt
 
-    print("=" * 10)
-    print("Prompt:", prompt)
-    prompt = tokenizer.encode(prompt)
-    prompt = mx.array(prompt)
-    tic = time.time()
-    tokens = []
-    skip = 0
-    for token, n in zip(
-        generate_step(prompt, model, args.temp, args.colorize), range(args.max_tokens)
-    ):
-        token, t0 = token
-        if token == tokenizer.eos_token_id:
-            break
-        if n == 0:
-            prompt_time = time.time() - tic
-            tic = time.time()
-        tokens.append(token.item())
-        s = tokenizer.decode(tokens)
-        if args.colorize:
-            colorprint_by_t0(t0,s[skip:])
-        else:
-            print(s[skip:], end="", flush=True)
-        skip = len(s)
-    print(tokenizer.decode(tokens)[skip:], flush=True)
-    gen_time = time.time() - tic
-    print("=" * 10)
-    if len(tokens) == 0:
-        print("No tokens generated for this prompt")
-        return
-    prompt_tps = prompt.size / prompt_time
-    gen_tps = (len(tokens) - 1) / gen_time
-    print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
-    print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+    formatter = colorprint_by_t0 if args.colorize else None
+
+    generate(
+        model, tokenizer, prompt, args.temp, args.max_tokens, True, formatter=formatter
+    )
 
 
 if __name__ == "__main__":
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index bc35f9f6..522208c1 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -2,8 +2,9 @@ import copy
 import glob
 import json
 import logging
+import time
 from pathlib import Path
-from typing import Any, Dict, Generator, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -80,38 +81,37 @@ def get_model_path(path_or_hf_repo: str) -> Path:
 
 
 def generate_step(
-    prompt: mx.array, model: nn.Module, temp: float = 0.0, return_probability: bool = False
-) -> Generator[mx.array, None, None]:
+    prompt: mx.array,
+    model: nn.Module,
+    temp: float = 0.0,
+) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.
 
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        temp (float): The temperature for sampling. If temp is 0, use max sampling.
-        return_probability (bool): Whether to return the probability of generated token,
+        temp (float): The temperature for sampling, if 0 the argmax is used.
 
     Yields:
-        Generator[mx.array]: A generator producing one token per call.
+        Generator[Tuple[mx.array, mx.array]]: A generator producing
+        one token and probability per call.
     """
 
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
-        prop = 1
         if temp == 0:
             token = mx.argmax(logits, axis=-1)
         else:
             token = mx.random.categorical(logits * (1 / temp))
-            if return_probability:
-                probs = mx.softmax(logits / temp)
-                prop = probs[0, token.item()]
-        return token, prop
+        prob = mx.softmax(logits / temp)[0, token]
+        return token, prob
 
     y = prompt
     cache = None
     while True:
         logits, cache = model(y[None], cache=cache)
         logits = logits[:, -1, :]
-        y, t0 = sample(logits)
-        yield y, t0
+        y, prob = sample(logits)
+        yield y, prob
 
 
 def generate(
@@ -121,6 +121,7 @@ def generate(
     temp: float = 0.0,
     max_tokens: int = 100,
     verbose: bool = False,
+    formatter: Callable = None,
 ) -> str:
     """
     Generate text from the model.
@@ -131,29 +132,54 @@
         prompt (str): The string prompt.
         temp (float): The temperature for sampling (default 0).
         max_tokens (int): The maximum number of tokens (default 100).
+        verbose (bool): If ``True``, print tokens and timing information
+            (default ``False``).
+        formatter (Optional[Callable]): A function which takes a token and a
+            probability and displays it.
     """
+    if verbose:
+        print("=" * 10)
+        print("Prompt:", prompt)
 
     prompt = mx.array(tokenizer.encode(prompt))
 
+    tic = time.time()
     tokens = []
     skip = 0
     REPLACEMENT_CHAR = "\ufffd"
 
-    for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)):
+    for (token, prob), n in zip(generate_step(prompt, model, temp), range(max_tokens)):
         if token == tokenizer.eos_token_id:
             break
-
+        if n == 0:
+            prompt_time = time.time() - tic
+            tic = time.time()
         tokens.append(token.item())
 
         if verbose:
             s = tokenizer.decode(tokens)
-            if REPLACEMENT_CHAR not in s:
+            if formatter:
+                formatter(s[skip:], prob.item())
+                skip = len(s)
+            elif REPLACEMENT_CHAR not in s:
                 print(s[skip:], end="", flush=True)
                 skip = len(s)
 
     tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
+
    if verbose:
         print(tokens[skip:], flush=True)
+        gen_time = time.time() - tic
+        print("=" * 10)
+        if len(tokens) == 0:
+            print("No tokens generated for this prompt")
+            return
+        prompt_tps = prompt.size / prompt_time
+        gen_tps = (len(tokens) - 1) / gen_time
+        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
+        print(f"Generation: {gen_tps:.3f} tokens-per-sec")
 
     return tokens
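For reference, a minimal usage sketch of the generate entry point as it looks after this patch, mirroring how main() in llms/mlx_lm/generate.py now calls it. The import path, model path, and prompt below are placeholder assumptions rather than part of the patch; pass colorprint_by_t0 (or any callable taking a text chunk and a probability) as formatter only when colorized output is wanted.

    # Hypothetical usage sketch of the post-patch API; the model path and
    # prompt are placeholders, not values taken from the patch itself.
    from mlx_lm.utils import generate, load

    # load() is assumed here to return a (model, tokenizer) pair, as main() uses it.
    model, tokenizer = load("mlx_model")

    # verbose=True streams tokens as they are decoded and prints the timing summary;
    # with formatter=None the plain print path (with the replacement-character check) is used.
    text = generate(
        model,
        tokenizer,
        "hello",
        temp=0.0,
        max_tokens=100,
        verbose=True,
        formatter=None,
    )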