Fix Chinese character generation, same as PR #321 (#342)

* Fix Chinese character generation, same as PR #321

* Move the generation logic into utils.py so it can be reused

* Format

* Verbose default

* Fix conflicts with colorize and the character check

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Authored by iLoveBug on 2024-01-24 04:44:23 +08:00; committed by GitHub
parent 21aa8038fb
commit 40b61c1719
2 changed files with 63 additions and 66 deletions
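
Background for the change: when output is streamed token by token, a multi-byte UTF-8 character (such as a Chinese character) can be split across tokens, and decoding an incomplete byte sequence yields the replacement character U+FFFD instead of the intended text. The REPLACEMENT_CHAR handling in the utils.py diff below guards against exactly this. A minimal, self-contained sketch of the effect, using raw UTF-8 bytes rather than a real tokenizer:

    # Two Chinese characters; each one is 3 bytes in UTF-8.
    data = "你好".encode("utf-8")

    # Decode progressively longer prefixes, as a streaming printer would.
    for i in range(1, len(data) + 1):
        s = data[:i].decode("utf-8", errors="replace")
        print(i, repr(s), "\ufffd" in s)

    # Prefixes that cut a character mid-byte contain "\ufffd"; only the
    # complete prefixes (i = 3 and i = 6) decode cleanly, which is why the
    # streaming loop below waits until the replacement character is gone
    # before printing the new text and advancing `skip`.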

Changed file 1: the command-line generation script (imports generate and load from .utils)

@@ -1,9 +1,8 @@
 import argparse
-import time
 
 import mlx.core as mx
 
-from .utils import generate_step, load
+from .utils import generate, load
 
 DEFAULT_MODEL_PATH = "mlx_model"
 DEFAULT_PROMPT = "hello"
@@ -53,7 +52,7 @@ def setup_arg_parser():
     )
     parser.add_argument(
         "--colorize",
-        action='store_true',
+        action="store_true",
         help="Colorize output based on T[0] probability",
     )
     return parser
@@ -61,29 +60,29 @@ def setup_arg_parser():
 def colorprint(color, s):
     color_codes = {
-        'black': 30,
-        'red': 31,
-        'green': 32,
-        'yellow': 33,
-        'blue': 34,
-        'magenta': 35,
-        'cyan': 36,
-        'white': 39,
+        "black": 30,
+        "red": 31,
+        "green": 32,
+        "yellow": 33,
+        "blue": 34,
+        "magenta": 35,
+        "cyan": 36,
+        "white": 39,
     }
     ccode = color_codes.get(color, 30)
     print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
 
 
-def colorprint_by_t0(t0, s):
+def colorprint_by_t0(s, t0):
     if t0 > 0.95:
-        color = 'white'
+        color = "white"
     elif t0 > 0.70:
-        color = 'green'
+        color = "green"
     elif t0 > 0.30:
-        color = 'yellow'
+        color = "yellow"
     else:
-        color = 'red'
-    colorprint(color,s)
+        color = "red"
+    colorprint(color, s)
 
 
 def main(args):
@@ -107,39 +106,11 @@ def main(args):
     else:
         prompt = args.prompt
 
-    print("=" * 10)
-    print("Prompt:", prompt)
-    prompt = tokenizer.encode(prompt)
-    prompt = mx.array(prompt)
-    tic = time.time()
-    tokens = []
-    skip = 0
-    for token, n in zip(
-        generate_step(prompt, model, args.temp, args.colorize), range(args.max_tokens)
-    ):
-        token, t0 = token
-        if token == tokenizer.eos_token_id:
-            break
-        if n == 0:
-            prompt_time = time.time() - tic
-            tic = time.time()
-        tokens.append(token.item())
-        s = tokenizer.decode(tokens)
-        if args.colorize:
-            colorprint_by_t0(t0,s[skip:])
-        else:
-            print(s[skip:], end="", flush=True)
-        skip = len(s)
-    print(tokenizer.decode(tokens)[skip:], flush=True)
-    gen_time = time.time() - tic
-    print("=" * 10)
-    if len(tokens) == 0:
-        print("No tokens generated for this prompt")
-        return
-    prompt_tps = prompt.size / prompt_time
-    gen_tps = (len(tokens) - 1) / gen_time
-    print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
-    print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+    formatter = colorprint_by_t0 if args.colorize else None
+
+    generate(
+        model, tokenizer, prompt, args.temp, args.max_tokens, True, formatter=formatter
+    )
 
 
 if __name__ == "__main__":
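
For reference, the formatter handed to generate() is called with the newly decoded text segment and the probability of the current token (formatter(s[skip:], prob.item()) in the utils.py diff below), the same (text, probability) order colorprint_by_t0 now uses. A sketch of a custom formatter following that convention; the function name here is hypothetical:

    def print_with_prob(segment, prob):
        # Hypothetical formatter: prints each decoded segment together with
        # the probability of the token that produced it.
        print(f"{segment}[{prob:.2f}]", end="", flush=True)

    # Passed the same way the CLI above passes colorprint_by_t0:
    # generate(model, tokenizer, prompt, temp, max_tokens, True, formatter=print_with_prob)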

Changed file 2: utils.py

@@ -2,8 +2,9 @@ import copy
 import glob
 import json
 import logging
+import time
 from pathlib import Path
-from typing import Any, Dict, Generator, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -80,38 +81,37 @@ def get_model_path(path_or_hf_repo: str) -> Path:
 def generate_step(
-    prompt: mx.array, model: nn.Module, temp: float = 0.0, return_probability: bool = False
-) -> Generator[mx.array, None, None]:
+    prompt: mx.array,
+    model: nn.Module,
+    temp: float = 0.0,
+) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.
 
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        temp (float): The temperature for sampling. If temp is 0, use max sampling.
-        return_probability (bool): Whether to return the probability of generated token,
+        temp (float): The temperature for sampling, if 0 the argmax is used.
 
     Yields:
-        Generator[mx.array]: A generator producing one token per call.
+        Generator[Tuple[mx.array, mx.array]]: A generator producing
+        one token and probability per call.
     """
 
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
-        prop = 1
         if temp == 0:
             token = mx.argmax(logits, axis=-1)
         else:
             token = mx.random.categorical(logits * (1 / temp))
-        if return_probability:
-            probs = mx.softmax(logits / temp)
-            prop = probs[0, token.item()]
-        return token, prop
+        prob = mx.softmax(logits / temp)[0, token]
+        return token, prob
 
     y = prompt
     cache = None
     while True:
         logits, cache = model(y[None], cache=cache)
         logits = logits[:, -1, :]
-        y, t0 = sample(logits)
-        yield y, t0
+        y, prob = sample(logits)
+        yield y, prob
@@ -121,6 +121,7 @@ def generate(
     temp: float = 0.0,
     max_tokens: int = 100,
     verbose: bool = False,
+    formatter: Callable = None,
 ) -> str:
     """
     Generate text from the model.
@@ -131,29 +132,54 @@
         prompt (str): The string prompt.
         temp (float): The temperature for sampling (default 0).
         max_tokens (int): The maximum number of tokens (default 100).
+        verbose (bool): If ``True``, print tokens and timing information
+            (default ``False``).
+        formatter (Optional[Callable]): A function which takes a token and a
+            probability and displays it.
     """
+    if verbose:
+        print("=" * 10)
+        print("Prompt:", prompt)
+
     prompt = mx.array(tokenizer.encode(prompt))
 
+    tic = time.time()
     tokens = []
     skip = 0
     REPLACEMENT_CHAR = "\ufffd"
-    for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)):
+    for (token, prob), n in zip(generate_step(prompt, model, temp), range(max_tokens)):
         if token == tokenizer.eos_token_id:
             break
+        if n == 0:
+            prompt_time = time.time() - tic
+            tic = time.time()
         tokens.append(token.item())
 
         if verbose:
             s = tokenizer.decode(tokens)
-            if REPLACEMENT_CHAR not in s:
+            if formatter:
+                formatter(s[skip:], prob.item())
+                skip = len(s)
+            elif REPLACEMENT_CHAR not in s:
                 print(s[skip:], end="", flush=True)
                 skip = len(s)
 
     tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
 
     if verbose:
         print(tokens[skip:], flush=True)
+        gen_time = time.time() - tic
+        print("=" * 10)
+        if len(tokens) == 0:
+            print("No tokens generated for this prompt")
+            return
+        prompt_tps = prompt.size / prompt_time
+        gen_tps = (len(tokens) - 1) / gen_time
+        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
+        print(f"Generation: {gen_tps:.3f} tokens-per-sec")
 
     return tokens
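
With the loop consolidated here, other code in the same package can call the shared helper directly instead of duplicating it. A hedged usage sketch, assuming (as the CLI script above does) that the package-relative import resolves and that load returns the (model, tokenizer) pair passed on to generate:

    from .utils import generate, load  # package-relative import, as in the CLI script

    model, tokenizer = load("mlx_model")  # assumed to return the (model, tokenizer) pair; DEFAULT_MODEL_PATH in the CLI

    text = generate(
        model,
        tokenizer,
        "hello",          # DEFAULT_PROMPT in the CLI
        temp=0.0,         # 0 means argmax sampling
        max_tokens=100,
        verbose=True,     # prints the prompt, streamed text, and tokens-per-sec
    )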