mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-09 18:36:38 +08:00
Add colorized output option to generate script (#347)
* Add colorized output option to generate script. Two new functions were added to the script that allow output to be colorized based on the T[0] probability. Changes were made to the `generate_step` function in utils.py to permit colorization. Additionally, an argument for colorization was introduced to the command-line parser. * Rename the 'colorize' parameter to 'return_probability' in generate_step
This commit is contained in:
parent
a445ac2895
commit
c45c2311bd
@ -51,9 +51,41 @@ def setup_arg_parser():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Use the raw prompt without the tokenizer's chat template.",
|
help="Use the raw prompt without the tokenizer's chat template.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--colorize",
|
||||||
|
action='store_true',
|
||||||
|
help="Colorize output based on T[0] probability",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def colorprint(color, s):
    """Print *s* in bold with the given ANSI foreground color, no trailing newline.

    Unknown color names fall back to black (code 30). Note 'white' maps to 39,
    the terminal's default foreground, which reads as white on dark themes.
    """
    ansi_fg = {
        'black': 30,
        'red': 31,
        'green': 32,
        'yellow': 33,
        'blue': 34,
        'magenta': 35,
        'cyan': 36,
        'white': 39,
    }
    code = ansi_fg.get(color, 30)
    # \033[1m = bold, \033[<code>m = foreground color, \033[0m = reset.
    print(f"\033[1m\033[{code}m{s}\033[0m", end="", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
def colorprint_by_t0(t0, s):
    """Print *s* colorized according to the T[0] probability *t0*.

    Higher-confidence tokens get calmer colors:
    > 0.95 white, > 0.70 green, > 0.30 yellow, otherwise red.
    """
    buckets = ((0.95, 'white'), (0.70, 'green'), (0.30, 'yellow'))
    for threshold, color in buckets:
        if t0 > threshold:
            break
    else:
        # t0 <= 0.30: low confidence.
        color = 'red'
    colorprint(color, s)
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
mx.random.seed(args.seed)
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
@ -83,8 +115,9 @@ def main(args):
|
|||||||
tokens = []
|
tokens = []
|
||||||
skip = 0
|
skip = 0
|
||||||
for token, n in zip(
|
for token, n in zip(
|
||||||
generate_step(prompt, model, args.temp), range(args.max_tokens)
|
generate_step(prompt, model, args.temp, args.colorize), range(args.max_tokens)
|
||||||
):
|
):
|
||||||
|
token, t0 = token
|
||||||
if token == tokenizer.eos_token_id:
|
if token == tokenizer.eos_token_id:
|
||||||
break
|
break
|
||||||
if n == 0:
|
if n == 0:
|
||||||
@ -92,7 +125,10 @@ def main(args):
|
|||||||
tic = time.time()
|
tic = time.time()
|
||||||
tokens.append(token.item())
|
tokens.append(token.item())
|
||||||
s = tokenizer.decode(tokens)
|
s = tokenizer.decode(tokens)
|
||||||
print(s[skip:], end="", flush=True)
|
if args.colorize:
|
||||||
|
colorprint_by_t0(t0,s[skip:])
|
||||||
|
else:
|
||||||
|
print(s[skip:], end="", flush=True)
|
||||||
skip = len(s)
|
skip = len(s)
|
||||||
print(tokenizer.decode(tokens)[skip:], flush=True)
|
print(tokenizer.decode(tokens)[skip:], flush=True)
|
||||||
gen_time = time.time() - tic
|
gen_time = time.time() - tic
|
||||||
|
@ -77,7 +77,7 @@ def get_model_path(path_or_hf_repo: str) -> Path:
|
|||||||
|
|
||||||
|
|
||||||
def generate_step(
    prompt: mx.array, model: nn.Module, temp: float = 0.0, return_probability: bool = False
) -> Generator[mx.array, None, None]:
    """
    A generator producing text based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        temp (float): The temperature for sampling. If temp is 0, use max sampling.
        return_probability (bool): Whether to compute the probability of the
            generated token; when False the yielded probability is always 1.

    Yields:
        Generator[mx.array]: A generator producing one (token, probability)
        pair per call.
    """

    def sample(logits: mx.array) -> Tuple[mx.array, float]:
        # Default probability of 1 keeps callers that ignore it unaffected.
        prop = 1
        if temp == 0:
            token = mx.argmax(logits, axis=-1)
        else:
            token = mx.random.categorical(logits * (1 / temp))
        if return_probability:
            # BUGFIX: the original computed mx.softmax(logits / temp)
            # unconditionally, which divides by zero when temp == 0 (greedy
            # sampling) and yields inf/NaN probabilities. Greedy sampling
            # uses the unscaled logits instead.
            if temp == 0:
                probs = mx.softmax(logits)
            else:
                probs = mx.softmax(logits * (1 / temp))
            prop = probs[0, token.item()]
        return token, prop

    y = prompt
    cache = None
    while True:
        # Model returns per-position vocab logits and the updated KV cache;
        # only the last position's logits are needed for the next token.
        logits, cache = model(y[None], cache=cache)
        logits = logits[:, -1, :]
        y, t0 = sample(logits)
        yield y, t0
|
||||||
|
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
|
Loading…
Reference in New Issue
Block a user