mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
* fix the chinese character generation as same as PR #321 * reuse the generate logic to utils.py * format * verbose defualt * fix conflicst with colorize and character check --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
21aa8038fb
commit
40b61c1719
@ -1,9 +1,8 @@
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .utils import generate_step, load
|
||||
from .utils import generate, load
|
||||
|
||||
DEFAULT_MODEL_PATH = "mlx_model"
|
||||
DEFAULT_PROMPT = "hello"
|
||||
@ -53,7 +52,7 @@ def setup_arg_parser():
|
||||
)
|
||||
parser.add_argument(
|
||||
"--colorize",
|
||||
action='store_true',
|
||||
action="store_true",
|
||||
help="Colorize output based on T[0] probability",
|
||||
)
|
||||
return parser
|
||||
@ -61,29 +60,29 @@ def setup_arg_parser():
|
||||
|
||||
def colorprint(color, s):
|
||||
color_codes = {
|
||||
'black': 30,
|
||||
'red': 31,
|
||||
'green': 32,
|
||||
'yellow': 33,
|
||||
'blue': 34,
|
||||
'magenta': 35,
|
||||
'cyan': 36,
|
||||
'white': 39,
|
||||
"black": 30,
|
||||
"red": 31,
|
||||
"green": 32,
|
||||
"yellow": 33,
|
||||
"blue": 34,
|
||||
"magenta": 35,
|
||||
"cyan": 36,
|
||||
"white": 39,
|
||||
}
|
||||
ccode = color_codes.get(color, 30)
|
||||
print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
|
||||
|
||||
|
||||
def colorprint_by_t0(t0, s):
|
||||
def colorprint_by_t0(s, t0):
|
||||
if t0 > 0.95:
|
||||
color = 'white'
|
||||
color = "white"
|
||||
elif t0 > 0.70:
|
||||
color = 'green'
|
||||
color = "green"
|
||||
elif t0 > 0.30:
|
||||
color = 'yellow'
|
||||
color = "yellow"
|
||||
else:
|
||||
color = 'red'
|
||||
colorprint(color,s)
|
||||
color = "red"
|
||||
colorprint(color, s)
|
||||
|
||||
|
||||
def main(args):
|
||||
@ -107,39 +106,11 @@ def main(args):
|
||||
else:
|
||||
prompt = args.prompt
|
||||
|
||||
print("=" * 10)
|
||||
print("Prompt:", prompt)
|
||||
prompt = tokenizer.encode(prompt)
|
||||
prompt = mx.array(prompt)
|
||||
tic = time.time()
|
||||
tokens = []
|
||||
skip = 0
|
||||
for token, n in zip(
|
||||
generate_step(prompt, model, args.temp, args.colorize), range(args.max_tokens)
|
||||
):
|
||||
token, t0 = token
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
if n == 0:
|
||||
prompt_time = time.time() - tic
|
||||
tic = time.time()
|
||||
tokens.append(token.item())
|
||||
s = tokenizer.decode(tokens)
|
||||
if args.colorize:
|
||||
colorprint_by_t0(t0,s[skip:])
|
||||
else:
|
||||
print(s[skip:], end="", flush=True)
|
||||
skip = len(s)
|
||||
print(tokenizer.decode(tokens)[skip:], flush=True)
|
||||
gen_time = time.time() - tic
|
||||
print("=" * 10)
|
||||
if len(tokens) == 0:
|
||||
print("No tokens generated for this prompt")
|
||||
return
|
||||
prompt_tps = prompt.size / prompt_time
|
||||
gen_tps = (len(tokens) - 1) / gen_time
|
||||
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||
formatter = colorprint_by_t0 if args.colorize else None
|
||||
|
||||
generate(
|
||||
model, tokenizer, prompt, args.temp, args.max_tokens, True, formatter=formatter
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -2,8 +2,9 @@ import copy
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Generator, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -80,38 +81,37 @@ def get_model_path(path_or_hf_repo: str) -> Path:
|
||||
|
||||
|
||||
def generate_step(
|
||||
prompt: mx.array, model: nn.Module, temp: float = 0.0, return_probability: bool = False
|
||||
) -> Generator[mx.array, None, None]:
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
temp: float = 0.0,
|
||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||
"""
|
||||
A generator producing text based on the given prompt from the model.
|
||||
|
||||
Args:
|
||||
prompt (mx.array): The input prompt.
|
||||
model (nn.Module): The model to use for generation.
|
||||
temp (float): The temperature for sampling. If temp is 0, use max sampling.
|
||||
return_probability (bool): Whether to return the probability of generated token,
|
||||
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||
Yields:
|
||||
Generator[mx.array]: A generator producing one token per call.
|
||||
Generator[Tuple[mx.array, mx.array]]: A generator producing
|
||||
one token and probability per call.
|
||||
"""
|
||||
|
||||
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
||||
prop = 1
|
||||
if temp == 0:
|
||||
token = mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
token = mx.random.categorical(logits * (1 / temp))
|
||||
if return_probability:
|
||||
probs = mx.softmax(logits / temp)
|
||||
prop = probs[0, token.item()]
|
||||
return token, prop
|
||||
prob = mx.softmax(logits / temp)[0, token]
|
||||
return token, prob
|
||||
|
||||
y = prompt
|
||||
cache = None
|
||||
while True:
|
||||
logits, cache = model(y[None], cache=cache)
|
||||
logits = logits[:, -1, :]
|
||||
y, t0 = sample(logits)
|
||||
yield y, t0
|
||||
y, prob = sample(logits)
|
||||
yield y, prob
|
||||
|
||||
|
||||
def generate(
|
||||
@ -121,6 +121,7 @@ def generate(
|
||||
temp: float = 0.0,
|
||||
max_tokens: int = 100,
|
||||
verbose: bool = False,
|
||||
formatter: Callable = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate text from the model.
|
||||
@ -131,29 +132,54 @@ def generate(
|
||||
prompt (str): The string prompt.
|
||||
temp (float): The temperature for sampling (default 0).
|
||||
max_tokens (int): The maximum number of tokens (default 100).
|
||||
verbose (bool): If ``True``, print tokens and timing information
|
||||
(default ``False``).
|
||||
formatter (Optional[Callable]): A function which takes a token and a
|
||||
probability and displays it.
|
||||
"""
|
||||
|
||||
if verbose:
|
||||
print("=" * 10)
|
||||
print("Prompt:", prompt)
|
||||
|
||||
prompt = mx.array(tokenizer.encode(prompt))
|
||||
|
||||
tic = time.time()
|
||||
tokens = []
|
||||
skip = 0
|
||||
REPLACEMENT_CHAR = "\ufffd"
|
||||
|
||||
for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)):
|
||||
for (token, prob), n in zip(generate_step(prompt, model, temp), range(max_tokens)):
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
if n == 0:
|
||||
prompt_time = time.time() - tic
|
||||
tic = time.time()
|
||||
tokens.append(token.item())
|
||||
|
||||
if verbose:
|
||||
s = tokenizer.decode(tokens)
|
||||
if REPLACEMENT_CHAR not in s:
|
||||
if formatter:
|
||||
formatter(s[skip:], prob.item())
|
||||
skip = len(s)
|
||||
elif REPLACEMENT_CHAR not in s:
|
||||
print(s[skip:], end="", flush=True)
|
||||
skip = len(s)
|
||||
|
||||
tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
|
||||
|
||||
if verbose:
|
||||
print(tokens[skip:], flush=True)
|
||||
gen_time = time.time() - tic
|
||||
print("=" * 10)
|
||||
if len(tokens) == 0:
|
||||
print("No tokens generated for this prompt")
|
||||
return
|
||||
prompt_tps = prompt.size / prompt_time
|
||||
gen_tps = (len(tokens) - 1) / gen_time
|
||||
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user