fix Chinese character generation the same way as PR #321 (#342)

* fix Chinese character generation the same way as PR #321

* move the generation logic into utils.py for reuse

* format

* verbose default

* fix conflicts with colorize and the character check

---------

Co-authored-by: Awni Hannun <awni@apple.com>
iLoveBug 2024-01-24 04:44:23 +08:00 committed by GitHub
parent 21aa8038fb
commit 40b61c1719
2 changed files with 63 additions and 66 deletions

--- a/generate.py
+++ b/generate.py

@@ -1,9 +1,8 @@
 import argparse
-import time
 
 import mlx.core as mx
 
-from .utils import generate_step, load
+from .utils import generate, load
 
 DEFAULT_MODEL_PATH = "mlx_model"
 DEFAULT_PROMPT = "hello"
@@ -53,7 +52,7 @@ def setup_arg_parser():
     )
     parser.add_argument(
         "--colorize",
-        action='store_true',
+        action="store_true",
         help="Colorize output based on T[0] probability",
     )
     return parser
@@ -61,28 +60,28 @@ def setup_arg_parser():
 def colorprint(color, s):
     color_codes = {
-        'black': 30,
-        'red': 31,
-        'green': 32,
-        'yellow': 33,
-        'blue': 34,
-        'magenta': 35,
-        'cyan': 36,
-        'white': 39,
+        "black": 30,
+        "red": 31,
+        "green": 32,
+        "yellow": 33,
+        "blue": 34,
+        "magenta": 35,
+        "cyan": 36,
+        "white": 39,
     }
     ccode = color_codes.get(color, 30)
     print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
 
 
-def colorprint_by_t0(t0, s):
+def colorprint_by_t0(s, t0):
     if t0 > 0.95:
-        color = 'white'
+        color = "white"
     elif t0 > 0.70:
-        color = 'green'
+        color = "green"
     elif t0 > 0.30:
-        color = 'yellow'
+        color = "yellow"
     else:
-        color = 'red'
+        color = "red"
     colorprint(color, s)
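
For reference, the probability-to-color mapping above can be exercised directly. A minimal sketch (assuming an ANSI-capable terminal), not part of this diff:

for p in (0.99, 0.80, 0.50, 0.10):
    colorprint_by_t0(f"p={p} ", p)  # prints white, green, yellow, red in turn
print()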
@@ -107,39 +106,11 @@ def main(args):
     else:
         prompt = args.prompt
 
-    print("=" * 10)
-    print("Prompt:", prompt)
-    prompt = tokenizer.encode(prompt)
-    prompt = mx.array(prompt)
-    tic = time.time()
-    tokens = []
-    skip = 0
-    for token, n in zip(
-        generate_step(prompt, model, args.temp, args.colorize), range(args.max_tokens)
-    ):
-        token, t0 = token
-        if token == tokenizer.eos_token_id:
-            break
-        if n == 0:
-            prompt_time = time.time() - tic
-            tic = time.time()
-        tokens.append(token.item())
-        s = tokenizer.decode(tokens)
-        if args.colorize:
-            colorprint_by_t0(t0,s[skip:])
-        else:
-            print(s[skip:], end="", flush=True)
-        skip = len(s)
-    print(tokenizer.decode(tokens)[skip:], flush=True)
-    gen_time = time.time() - tic
-    print("=" * 10)
-    if len(tokens) == 0:
-        print("No tokens generated for this prompt")
-        return
-    prompt_tps = prompt.size / prompt_time
-    gen_tps = (len(tokens) - 1) / gen_time
-    print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
-    print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+    formatter = colorprint_by_t0 if args.colorize else None
+    generate(
+        model, tokenizer, prompt, args.temp, args.max_tokens, True, formatter=formatter
+    )
 
 
 if __name__ == "__main__":

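With this change the CLI no longer drives generate_step itself. A minimal usage sketch of the relocated helper (the package import path and model directory are assumptions, not part of this diff):

from mlx_lm.utils import generate, load

model, tokenizer = load("mlx_model")  # hypothetical local MLX model directory
# verbose=True prints the prompt, streamed tokens, and tokens-per-sec timing
text = generate(model, tokenizer, "hello", 0.7, 100, True)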
--- a/utils.py
+++ b/utils.py

@@ -2,8 +2,9 @@ import copy
 import glob
 import json
 import logging
+import time
 from pathlib import Path
-from typing import Any, Dict, Generator, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -80,38 +81,37 @@ def get_model_path(path_or_hf_repo: str) -> Path:
 def generate_step(
-    prompt: mx.array, model: nn.Module, temp: float = 0.0, return_probability: bool = False
-) -> Generator[mx.array, None, None]:
+    prompt: mx.array,
+    model: nn.Module,
+    temp: float = 0.0,
+) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.
 
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        temp (float): The temperature for sampling. If temp is 0, use max sampling.
-        return_probability (bool): Whether to return the probability of generated token,
+        temp (float): The temperature for sampling, if 0 the argmax is used.
 
     Yields:
-        Generator[mx.array]: A generator producing one token per call.
+        Generator[Tuple[mx.array, mx.array]]: A generator producing
+        one token and probability per call.
     """
 
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
-        prop = 1
         if temp == 0:
             token = mx.argmax(logits, axis=-1)
         else:
             token = mx.random.categorical(logits * (1 / temp))
-        if return_probability:
-            probs = mx.softmax(logits / temp)
-            prop = probs[0, token.item()]
-        return token, prop
+        prob = mx.softmax(logits / temp)[0, token]
+        return token, prob
 
     y = prompt
     cache = None
     while True:
         logits, cache = model(y[None], cache=cache)
         logits = logits[:, -1, :]
-        y, t0 = sample(logits)
-        yield y, t0
+        y, prob = sample(logits)
+        yield y, prob
 
 
 def generate(
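Since generate_step now always yields (token, probability) pairs, a direct consumer unpacks both. A short sketch, assuming model and tokenizer come from load() as above:

prompt = mx.array(tokenizer.encode("hello"))
for (token, prob), _ in zip(generate_step(prompt, model, temp=0.8), range(20)):
    if token.item() == tokenizer.eos_token_id:
        break
    # each step yields the sampled token id and its softmax probability
    print(tokenizer.decode([token.item()]), round(prob.item(), 3))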
@@ -121,6 +121,7 @@ def generate(
     temp: float = 0.0,
     max_tokens: int = 100,
     verbose: bool = False,
+    formatter: Callable = None,
 ) -> str:
     """
     Generate text from the model.
@@ -131,29 +132,54 @@ def generate(
         prompt (str): The string prompt.
         temp (float): The temperature for sampling (default 0).
         max_tokens (int): The maximum number of tokens (default 100).
+        verbose (bool): If ``True``, print tokens and timing information
+            (default ``False``).
+        formatter (Optional[Callable]): A function which takes a token and a
+            probability and displays it.
     """
+    if verbose:
+        print("=" * 10)
+        print("Prompt:", prompt)
 
     prompt = mx.array(tokenizer.encode(prompt))
 
+    tic = time.time()
     tokens = []
     skip = 0
     REPLACEMENT_CHAR = "\ufffd"
 
-    for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)):
+    for (token, prob), n in zip(generate_step(prompt, model, temp), range(max_tokens)):
         if token == tokenizer.eos_token_id:
             break
+        if n == 0:
+            prompt_time = time.time() - tic
+            tic = time.time()
         tokens.append(token.item())
 
         if verbose:
             s = tokenizer.decode(tokens)
-            if REPLACEMENT_CHAR not in s:
+            if formatter:
+                formatter(s[skip:], prob.item())
+                skip = len(s)
+            elif REPLACEMENT_CHAR not in s:
                 print(s[skip:], end="", flush=True)
                 skip = len(s)
 
     tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
     if verbose:
         print(tokens[skip:], flush=True)
+        gen_time = time.time() - tic
+        print("=" * 10)
+        if len(tokens) == 0:
+            print("No tokens generated for this prompt")
+            return
+        prompt_tps = prompt.size / prompt_time
+        gen_tps = (len(tokens) - 1) / gen_time
+        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
+        print(f"Generation: {gen_tps:.3f} tokens-per-sec")
 
     return tokens
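
The REPLACEMENT_CHAR guard above is the core of the Chinese-character fix: one CJK character is often split across several byte-level tokens, so decoding a partial token sequence yields U+FFFD until the final byte arrives. A standalone sketch of the effect (plain UTF-8, no tokenizer needed, not part of the diff):

data = "你".encode("utf-8")  # three bytes: e4 bd a0
for i in range(1, len(data) + 1):
    s = data[:i].decode("utf-8", errors="replace")
    print(i, repr(s))  # prints '�' for both partial prefixes, then '你'

generate therefore holds back streamed output while the decoded text still contains "\ufffd" and strips any leftover replacement characters at the end.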