mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-09 18:36:38 +08:00
Add colorized output option to generate script (#347)
* Add colorized output option to generate script Two new functions were added to the script that allow output to be colorized based on the T[0] probability. Changes were made to the `generate_step` function in utils.py to permit colorization. Additionally, an argument for colorization was introduced to the command-line parser. * Rename 'colorize' parameter with 'return_probability' in generate_step
This commit is contained in:
parent
a445ac2895
commit
c45c2311bd
@ -51,9 +51,41 @@ def setup_arg_parser():
|
||||
action="store_true",
|
||||
help="Use the raw prompt without the tokenizer's chat template.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--colorize",
|
||||
action='store_true',
|
||||
help="Colorize output based on T[0] probability",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def colorprint(color, s):
|
||||
color_codes = {
|
||||
'black': 30,
|
||||
'red': 31,
|
||||
'green': 32,
|
||||
'yellow': 33,
|
||||
'blue': 34,
|
||||
'magenta': 35,
|
||||
'cyan': 36,
|
||||
'white': 39,
|
||||
}
|
||||
ccode = color_codes.get(color, 30)
|
||||
print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
|
||||
|
||||
|
||||
def colorprint_by_t0(t0, s):
|
||||
if t0 > 0.95:
|
||||
color = 'white'
|
||||
elif t0 > 0.70:
|
||||
color = 'green'
|
||||
elif t0 > 0.30:
|
||||
color = 'yellow'
|
||||
else:
|
||||
color = 'red'
|
||||
colorprint(color,s)
|
||||
|
||||
|
||||
def main(args):
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
@ -83,8 +115,9 @@ def main(args):
|
||||
tokens = []
|
||||
skip = 0
|
||||
for token, n in zip(
|
||||
generate_step(prompt, model, args.temp), range(args.max_tokens)
|
||||
generate_step(prompt, model, args.temp, args.colorize), range(args.max_tokens)
|
||||
):
|
||||
token, t0 = token
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
if n == 0:
|
||||
@ -92,7 +125,10 @@ def main(args):
|
||||
tic = time.time()
|
||||
tokens.append(token.item())
|
||||
s = tokenizer.decode(tokens)
|
||||
print(s[skip:], end="", flush=True)
|
||||
if args.colorize:
|
||||
colorprint_by_t0(t0,s[skip:])
|
||||
else:
|
||||
print(s[skip:], end="", flush=True)
|
||||
skip = len(s)
|
||||
print(tokenizer.decode(tokens)[skip:], flush=True)
|
||||
gen_time = time.time() - tic
|
||||
|
@ -77,7 +77,7 @@ def get_model_path(path_or_hf_repo: str) -> Path:
|
||||
|
||||
|
||||
def generate_step(
|
||||
prompt: mx.array, model: nn.Module, temp: float = 0.0
|
||||
prompt: mx.array, model: nn.Module, temp: float = 0.0, return_probability: bool = False
|
||||
) -> Generator[mx.array, None, None]:
|
||||
"""
|
||||
A generator producing text based on the given prompt from the model.
|
||||
@ -86,25 +86,29 @@ def generate_step(
|
||||
prompt (mx.array): The input prompt.
|
||||
model (nn.Module): The model to use for generation.
|
||||
temp (float): The temperature for sampling. If temp is 0, use max sampling.
|
||||
|
||||
return_probability (bool): Whether to return the probability of generated token,
|
||||
Yields:
|
||||
Generator[mx.array]: A generator producing one token per call.
|
||||
"""
|
||||
|
||||
def sample(logits: mx.array) -> mx.array:
|
||||
return (
|
||||
mx.argmax(logits, axis=-1)
|
||||
if temp == 0
|
||||
else mx.random.categorical(logits * (1 / temp))
|
||||
)
|
||||
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
||||
prop = 1
|
||||
if temp == 0:
|
||||
token = mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
token = mx.random.categorical(logits * (1 / temp))
|
||||
if return_probability:
|
||||
probs = mx.softmax(logits / temp)
|
||||
prop = probs[0, token.item()]
|
||||
return token, prop
|
||||
|
||||
y = prompt
|
||||
cache = None
|
||||
while True:
|
||||
logits, cache = model(y[None], cache=cache)
|
||||
logits = logits[:, -1, :]
|
||||
y = sample(logits)
|
||||
yield y
|
||||
y, t0 = sample(logits)
|
||||
yield y, t0
|
||||
|
||||
|
||||
def generate(
|
||||
|
Loading…
Reference in New Issue
Block a user