diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 31c06eb4..a6e89afa 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -51,9 +51,41 @@ def setup_arg_parser():
         action="store_true",
         help="Use the raw prompt without the tokenizer's chat template.",
     )
+    parser.add_argument(
+        "--colorize",
+        action="store_true",
+        help="Colorize output based on T[0] probability",
+    )
     return parser
 
 
+def colorprint(color, s):
+    color_codes = {
+        "black": 30,
+        "red": 31,
+        "green": 32,
+        "yellow": 33,
+        "blue": 34,
+        "magenta": 35,
+        "cyan": 36,
+        "white": 39,
+    }
+    ccode = color_codes.get(color, 30)
+    print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
+
+
+def colorprint_by_t0(t0, s):
+    if t0 > 0.95:
+        color = "white"
+    elif t0 > 0.70:
+        color = "green"
+    elif t0 > 0.30:
+        color = "yellow"
+    else:
+        color = "red"
+    colorprint(color, s)
+
+
 def main(args):
     mx.random.seed(args.seed)
 
@@ -83,8 +115,9 @@ def main(args):
     tokens = []
     skip = 0
     for token, n in zip(
-        generate_step(prompt, model, args.temp), range(args.max_tokens)
+        generate_step(prompt, model, args.temp, args.colorize), range(args.max_tokens)
     ):
+        token, t0 = token  # unpack the (token, probability) pair yielded by generate_step
         if token == tokenizer.eos_token_id:
             break
         if n == 0:
@@ -92,7 +125,10 @@ def main(args):
             tic = time.time()
         tokens.append(token.item())
         s = tokenizer.decode(tokens)
-        print(s[skip:], end="", flush=True)
+        if args.colorize:
+            colorprint_by_t0(t0, s[skip:])
+        else:
+            print(s[skip:], end="", flush=True)
         skip = len(s)
     print(tokenizer.decode(tokens)[skip:], flush=True)
     gen_time = time.time() - tic
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 5e8f8e2a..e0a877bb 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -77,7 +77,7 @@ def get_model_path(path_or_hf_repo: str) -> Path:
 
 
 def generate_step(
-    prompt: mx.array, model: nn.Module, temp: float = 0.0
+    prompt: mx.array, model: nn.Module, temp: float = 0.0, return_probability: bool = False
 ) -> Generator[mx.array, None, None]:
     """
     A generator producing text based on the given prompt from the model.
@@ -86,25 +86,29 @@ def generate_step(
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
         temp (float): The temperature for sampling. If temp is 0, use max sampling.
-
+        return_probability (bool): Whether to also yield the probability of the generated token.
     Yields:
         Generator[mx.array]: A generator producing one token per call.
     """
 
-    def sample(logits: mx.array) -> mx.array:
-        return (
-            mx.argmax(logits, axis=-1)
-            if temp == 0
-            else mx.random.categorical(logits * (1 / temp))
-        )
+    def sample(logits: mx.array) -> Tuple[mx.array, float]:
+        prob = 1.0
+        if temp == 0:
+            token = mx.argmax(logits, axis=-1)
+        else:
+            token = mx.random.categorical(logits * (1 / temp))
+        if return_probability:
+            probs = mx.softmax(logits * (1 / temp) if temp > 0 else logits)  # avoid div-by-zero at temp == 0
+            prob = probs[0, token.item()]
+        return token, prob
 
     y = prompt
     cache = None
     while True:
         logits, cache = model(y[None], cache=cache)
         logits = logits[:, -1, :]
-        y = sample(logits)
-        yield y
+        y, t0 = sample(logits)
+        yield y, t0
 
 
 def generate(
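
Usage sketch (not part of the patch): a minimal way to consume the modified generate_step, which now yields (token, probability) pairs when return_probability=True; this is the same information the CLI surfaces through the new --colorize flag. The model repo, prompt, and token budget below are placeholders.

import mlx.core as mx
from mlx_lm.utils import generate_step, load

# Placeholder model repo; any model loadable by mlx_lm works here.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")
prompt = mx.array(tokenizer.encode("Write a haiku about autumn."))

tokens = []
for (token, prob), _ in zip(
    generate_step(prompt, model, temp=0.7, return_probability=True),
    range(100),  # placeholder token budget
):
    if token == tokenizer.eos_token_id:
        break
    tokens.append(token.item())
    # prob is the sampled token's probability under the (temperature-scaled) distribution.
    print(f"{tokenizer.decode([token.item()])!r} p={prob.item():.2f}")
print(tokenizer.decode(tokens))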