Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
* fix the Chinese character generation in the same way as PR #321
* reuse the generate logic in utils.py
* format
* verbose default
* fix conflicts with colorize and character check

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent 21aa8038fb
commit 40b61c1719
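
The diff below moves the token loop out of the CLI's main() into utils.generate() and threads an optional formatter callback through it. As a rough usage sketch (not part of the commit): the mlx_lm import path, the "mlx_model" directory, and the red_if_unlikely formatter are assumptions for illustration; the load()/generate() call shapes follow the diff.

# Sketch only: assumes the package is importable as mlx_lm and that a converted
# model exists at "mlx_model" (the DEFAULT_MODEL_PATH used by generate.py).
from mlx_lm.utils import generate, load


def red_if_unlikely(s, prob):
    # Hypothetical formatter: receives the newly decoded text and the token
    # probability, like colorprint_by_t0(s, t0) in generate.py.
    color = "\033[31m" if prob < 0.3 else "\033[0m"
    print(f"{color}{s}\033[0m", end="", flush=True)


model, tokenizer = load("mlx_model")
generate(model, tokenizer, "hello", 0.0, 100, True, formatter=red_if_unlikely)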
generate.py
@@ -1,9 +1,8 @@
 import argparse
-import time
 
 import mlx.core as mx
 
-from .utils import generate_step, load
+from .utils import generate, load
 
 DEFAULT_MODEL_PATH = "mlx_model"
 DEFAULT_PROMPT = "hello"
@@ -53,7 +52,7 @@ def setup_arg_parser():
     )
     parser.add_argument(
         "--colorize",
-        action='store_true',
+        action="store_true",
         help="Colorize output based on T[0] probability",
     )
     return parser
@@ -61,29 +60,29 @@ def setup_arg_parser():
 
 def colorprint(color, s):
     color_codes = {
-        'black': 30,
-        'red': 31,
-        'green': 32,
-        'yellow': 33,
-        'blue': 34,
-        'magenta': 35,
-        'cyan': 36,
-        'white': 39,
+        "black": 30,
+        "red": 31,
+        "green": 32,
+        "yellow": 33,
+        "blue": 34,
+        "magenta": 35,
+        "cyan": 36,
+        "white": 39,
     }
     ccode = color_codes.get(color, 30)
     print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
 
 
-def colorprint_by_t0(t0, s):
+def colorprint_by_t0(s, t0):
     if t0 > 0.95:
-        color = 'white'
+        color = "white"
     elif t0 > 0.70:
-        color = 'green'
+        color = "green"
     elif t0 > 0.30:
-        color = 'yellow'
+        color = "yellow"
     else:
-        color = 'red'
-    colorprint(color,s)
+        color = "red"
+    colorprint(color, s)
 
 
 def main(args):
@@ -107,39 +106,11 @@ def main(args):
     else:
         prompt = args.prompt
 
-    print("=" * 10)
-    print("Prompt:", prompt)
-    prompt = tokenizer.encode(prompt)
-    prompt = mx.array(prompt)
-    tic = time.time()
-    tokens = []
-    skip = 0
-    for token, n in zip(
-        generate_step(prompt, model, args.temp, args.colorize), range(args.max_tokens)
-    ):
-        token, t0 = token
-        if token == tokenizer.eos_token_id:
-            break
-        if n == 0:
-            prompt_time = time.time() - tic
-            tic = time.time()
-        tokens.append(token.item())
-        s = tokenizer.decode(tokens)
-        if args.colorize:
-            colorprint_by_t0(t0,s[skip:])
-        else:
-            print(s[skip:], end="", flush=True)
-        skip = len(s)
-    print(tokenizer.decode(tokens)[skip:], flush=True)
-    gen_time = time.time() - tic
-    print("=" * 10)
-    if len(tokens) == 0:
-        print("No tokens generated for this prompt")
-        return
-    prompt_tps = prompt.size / prompt_time
-    gen_tps = (len(tokens) - 1) / gen_time
-    print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
-    print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+    formatter = colorprint_by_t0 if args.colorize else None
+
+    generate(
+        model, tokenizer, prompt, args.temp, args.max_tokens, True, formatter=formatter
+    )
 
 
 if __name__ == "__main__":
utils.py
@@ -2,8 +2,9 @@ import copy
 import glob
 import json
 import logging
+import time
 from pathlib import Path
-from typing import Any, Dict, Generator, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -80,38 +81,37 @@ def get_model_path(path_or_hf_repo: str) -> Path:
 
 
 def generate_step(
-    prompt: mx.array, model: nn.Module, temp: float = 0.0, return_probability: bool = False
-) -> Generator[mx.array, None, None]:
+    prompt: mx.array,
+    model: nn.Module,
+    temp: float = 0.0,
+) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.
 
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        temp (float): The temperature for sampling. If temp is 0, use max sampling.
-        return_probability (bool): Whether to return the probability of generated token,
+        temp (float): The temperature for sampling, if 0 the argmax is used.
     Yields:
-        Generator[mx.array]: A generator producing one token per call.
+        Generator[Tuple[mx.array, mx.array]]: A generator producing
+        one token and probability per call.
     """
 
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
-        prop = 1
         if temp == 0:
             token = mx.argmax(logits, axis=-1)
         else:
             token = mx.random.categorical(logits * (1 / temp))
-            if return_probability:
-                probs = mx.softmax(logits / temp)
-                prop = probs[0, token.item()]
-        return token, prop
+        prob = mx.softmax(logits / temp)[0, token]
+        return token, prob
 
     y = prompt
     cache = None
     while True:
         logits, cache = model(y[None], cache=cache)
         logits = logits[:, -1, :]
-        y, t0 = sample(logits)
-        yield y, t0
+        y, prob = sample(logits)
+        yield y, prob
 
 
 def generate(
@@ -121,6 +121,7 @@ def generate(
     temp: float = 0.0,
     max_tokens: int = 100,
     verbose: bool = False,
+    formatter: Callable = None,
 ) -> str:
     """
     Generate text from the model.
@@ -131,29 +132,54 @@ def generate(
         prompt (str): The string prompt.
        temp (float): The temperature for sampling (default 0).
         max_tokens (int): The maximum number of tokens (default 100).
+        verbose (bool): If ``True``, print tokens and timing information
+            (default ``False``).
+        formatter (Optional[Callable]): A function which takes a token and a
+            probability and displays it.
     """
 
+    if verbose:
+        print("=" * 10)
+        print("Prompt:", prompt)
+
     prompt = mx.array(tokenizer.encode(prompt))
+
+    tic = time.time()
     tokens = []
     skip = 0
     REPLACEMENT_CHAR = "\ufffd"
 
-    for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)):
+    for (token, prob), n in zip(generate_step(prompt, model, temp), range(max_tokens)):
         if token == tokenizer.eos_token_id:
             break
+        if n == 0:
+            prompt_time = time.time() - tic
+            tic = time.time()
         tokens.append(token.item())
 
         if verbose:
             s = tokenizer.decode(tokens)
-            if REPLACEMENT_CHAR not in s:
+            if formatter:
+                formatter(s[skip:], prob.item())
+                skip = len(s)
+            elif REPLACEMENT_CHAR not in s:
                 print(s[skip:], end="", flush=True)
                 skip = len(s)
 
     tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
 
     if verbose:
         print(tokens[skip:], flush=True)
+        gen_time = time.time() - tic
+        print("=" * 10)
+        if len(tokens) == 0:
+            print("No tokens generated for this prompt")
+            return
+        prompt_tps = prompt.size / prompt_time
+        gen_tps = (len(tokens) - 1) / gen_time
+        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
+        print(f"Generation: {gen_tps:.3f} tokens-per-sec")
 
     return tokens
 
 
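
The Chinese-character fix rests on the REPLACEMENT_CHAR check above: a CJK character spans several UTF-8 bytes that a tokenizer may split across tokens, so decoding a partial token list yields U+FFFD until the remaining bytes arrive. A small self-contained sketch of the effect, with plain byte truncation standing in for decoding a partial token list:

# Decoding an incomplete multi-byte character produces "\ufffd", which is why
# generate() holds back output while the replacement character is present.
raw = "你好".encode("utf-8")                     # 6 bytes for 2 CJK characters
partial = raw[:4].decode("utf-8", errors="replace")
print(repr(partial), "\ufffd" in partial)        # truncated -> contains U+FFFD
full = raw.decode("utf-8")
print(repr(full), "\ufffd" in full)              # complete -> safe to print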