Add colorized output option to generate script (#347)

* Add colorized output option to generate script

Two new functions were added to the script that allow output to be colorized based on the T[0] probability. Changes were made to the `generate_step` function in utils.py to permit colorization. Additionally, an argument for colorization was introduced to the command-line parser.

* Rename 'colorize' parameter with 'return_probability' in generate_step
This commit is contained in:
Ivan Fioravanti
2024-01-23 14:25:44 +01:00
committed by GitHub
parent a445ac2895
commit c45c2311bd
2 changed files with 52 additions and 12 deletions

View File

@@ -77,7 +77,7 @@ def get_model_path(path_or_hf_repo: str) -> Path:
def generate_step(
prompt: mx.array, model: nn.Module, temp: float = 0.0
prompt: mx.array, model: nn.Module, temp: float = 0.0, return_probability: bool = False
) -> Generator[mx.array, None, None]:
"""
A generator producing text based on the given prompt from the model.
@@ -86,25 +86,29 @@ def generate_step(
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling. If temp is 0, use max sampling.
return_probability (bool): Whether to return the probability of generated token,
Yields:
Generator[mx.array]: A generator producing one token per call.
"""
def sample(logits: mx.array) -> mx.array:
return (
mx.argmax(logits, axis=-1)
if temp == 0
else mx.random.categorical(logits * (1 / temp))
)
def sample(logits: mx.array) -> Tuple[mx.array, float]:
prop = 1
if temp == 0:
token = mx.argmax(logits, axis=-1)
else:
token = mx.random.categorical(logits * (1 / temp))
if return_probability:
probs = mx.softmax(logits / temp)
prop = probs[0, token.item()]
return token, prop
y = prompt
cache = None
while True:
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits)
yield y
y, t0 = sample(logits)
yield y, t0
def generate(