mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Add colorized output option to generate script (#347)
* Add colorized output option to generate script Two new functions were added to the script that allow output to be colorized based on the T[0] probability. Changes were made to the `generate_step` function in utils.py to permit colorization. Additionally, an argument for colorization was introduced to the command-line parser. * Rename 'colorize' parameter with 'return_probability' in generate_step
This commit is contained in:
@@ -77,7 +77,7 @@ def get_model_path(path_or_hf_repo: str) -> Path:
|
||||
|
||||
|
||||
def generate_step(
|
||||
prompt: mx.array, model: nn.Module, temp: float = 0.0
|
||||
prompt: mx.array, model: nn.Module, temp: float = 0.0, return_probability: bool = False
|
||||
) -> Generator[mx.array, None, None]:
|
||||
"""
|
||||
A generator producing text based on the given prompt from the model.
|
||||
@@ -86,25 +86,29 @@ def generate_step(
|
||||
prompt (mx.array): The input prompt.
|
||||
model (nn.Module): The model to use for generation.
|
||||
temp (float): The temperature for sampling. If temp is 0, use max sampling.
|
||||
|
||||
return_probability (bool): Whether to return the probability of generated token,
|
||||
Yields:
|
||||
Generator[mx.array]: A generator producing one token per call.
|
||||
"""
|
||||
|
||||
def sample(logits: mx.array) -> mx.array:
|
||||
return (
|
||||
mx.argmax(logits, axis=-1)
|
||||
if temp == 0
|
||||
else mx.random.categorical(logits * (1 / temp))
|
||||
)
|
||||
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
||||
prop = 1
|
||||
if temp == 0:
|
||||
token = mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
token = mx.random.categorical(logits * (1 / temp))
|
||||
if return_probability:
|
||||
probs = mx.softmax(logits / temp)
|
||||
prop = probs[0, token.item()]
|
||||
return token, prop
|
||||
|
||||
y = prompt
|
||||
cache = None
|
||||
while True:
|
||||
logits, cache = model(y[None], cache=cache)
|
||||
logits = logits[:, -1, :]
|
||||
y = sample(logits)
|
||||
yield y
|
||||
y, t0 = sample(logits)
|
||||
yield y, t0
|
||||
|
||||
|
||||
def generate(
|
||||
|
Reference in New Issue
Block a user