Add colorized output option to generate script (#347)

* Add colorized output option to generate script

Two new functions were added to the script that allow output to be colorized based on the T[0] probability. Changes were made to the `generate_step` function in utils.py to permit colorization. Additionally, an argument for colorization was introduced to the command-line parser.

* Rename 'colorize' parameter with 'return_probability' in generate_step
This commit is contained in:
Ivan Fioravanti 2024-01-23 14:25:44 +01:00 committed by GitHub
parent a445ac2895
commit c45c2311bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 12 deletions

View File

@ -51,9 +51,41 @@ def setup_arg_parser():
action="store_true", action="store_true",
help="Use the raw prompt without the tokenizer's chat template.", help="Use the raw prompt without the tokenizer's chat template.",
) )
parser.add_argument(
"--colorize",
action='store_true',
help="Colorize output based on T[0] probability",
)
return parser return parser
def colorprint(color, s):
color_codes = {
'black': 30,
'red': 31,
'green': 32,
'yellow': 33,
'blue': 34,
'magenta': 35,
'cyan': 36,
'white': 39,
}
ccode = color_codes.get(color, 30)
print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
def colorprint_by_t0(t0, s):
if t0 > 0.95:
color = 'white'
elif t0 > 0.70:
color = 'green'
elif t0 > 0.30:
color = 'yellow'
else:
color = 'red'
colorprint(color,s)
def main(args): def main(args):
mx.random.seed(args.seed) mx.random.seed(args.seed)
@ -83,8 +115,9 @@ def main(args):
tokens = [] tokens = []
skip = 0 skip = 0
for token, n in zip( for token, n in zip(
generate_step(prompt, model, args.temp), range(args.max_tokens) generate_step(prompt, model, args.temp, args.colorize), range(args.max_tokens)
): ):
token, t0 = token
if token == tokenizer.eos_token_id: if token == tokenizer.eos_token_id:
break break
if n == 0: if n == 0:
@ -92,6 +125,9 @@ def main(args):
tic = time.time() tic = time.time()
tokens.append(token.item()) tokens.append(token.item())
s = tokenizer.decode(tokens) s = tokenizer.decode(tokens)
if args.colorize:
colorprint_by_t0(t0,s[skip:])
else:
print(s[skip:], end="", flush=True) print(s[skip:], end="", flush=True)
skip = len(s) skip = len(s)
print(tokenizer.decode(tokens)[skip:], flush=True) print(tokenizer.decode(tokens)[skip:], flush=True)

View File

@ -77,7 +77,7 @@ def get_model_path(path_or_hf_repo: str) -> Path:
def generate_step( def generate_step(
prompt: mx.array, model: nn.Module, temp: float = 0.0 prompt: mx.array, model: nn.Module, temp: float = 0.0, return_probability: bool = False
) -> Generator[mx.array, None, None]: ) -> Generator[mx.array, None, None]:
""" """
A generator producing text based on the given prompt from the model. A generator producing text based on the given prompt from the model.
@ -86,25 +86,29 @@ def generate_step(
prompt (mx.array): The input prompt. prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation. model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling. If temp is 0, use max sampling. temp (float): The temperature for sampling. If temp is 0, use max sampling.
return_probability (bool): Whether to return the probability of generated token,
Yields: Yields:
Generator[mx.array]: A generator producing one token per call. Generator[mx.array]: A generator producing one token per call.
""" """
def sample(logits: mx.array) -> mx.array: def sample(logits: mx.array) -> Tuple[mx.array, float]:
return ( prop = 1
mx.argmax(logits, axis=-1) if temp == 0:
if temp == 0 token = mx.argmax(logits, axis=-1)
else mx.random.categorical(logits * (1 / temp)) else:
) token = mx.random.categorical(logits * (1 / temp))
if return_probability:
probs = mx.softmax(logits / temp)
prop = probs[0, token.item()]
return token, prop
y = prompt y = prompt
cache = None cache = None
while True: while True:
logits, cache = model(y[None], cache=cache) logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :] logits = logits[:, -1, :]
y = sample(logits) y, t0 = sample(logits)
yield y yield y, t0
def generate( def generate(