Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-12-15 01:42:31 +08:00)
* fix the Chinese character generation, the same as in PR #321
* reuse the generate logic in utils.py
* format
* verbose default
* fix conflicts with colorize and character check

---------

Co-authored-by: Awni Hannun <awni@apple.com>
@@ -2,8 +2,9 @@ import copy
 import glob
 import json
 import logging
+import time
 from pathlib import Path
-from typing import Any, Dict, Generator, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -80,38 +81,37 @@ def get_model_path(path_or_hf_repo: str) -> Path:
 
 
 def generate_step(
-    prompt: mx.array, model: nn.Module, temp: float = 0.0, return_probability: bool = False
-) -> Generator[mx.array, None, None]:
+    prompt: mx.array,
+    model: nn.Module,
+    temp: float = 0.0,
+) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.
 
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        temp (float): The temperature for sampling. If temp is 0, use max sampling.
-        return_probability (bool): Whether to return the probability of generated token,
+        temp (float): The temperature for sampling, if 0 the argmax is used.
     Yields:
-        Generator[mx.array]: A generator producing one token per call.
+        Generator[Tuple[mx.array, mx.array]]: A generator producing
+        one token and probability per call.
     """
 
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
-        prop = 1
         if temp == 0:
             token = mx.argmax(logits, axis=-1)
         else:
             token = mx.random.categorical(logits * (1 / temp))
-            if return_probability:
-                probs = mx.softmax(logits / temp)
-                prop = probs[0, token.item()]
-        return token, prop
+        prob = mx.softmax(logits / temp)[0, token]
+        return token, prob
 
     y = prompt
     cache = None
     while True:
         logits, cache = model(y[None], cache=cache)
         logits = logits[:, -1, :]
-        y, t0 = sample(logits)
-        yield y, t0
+        y, prob = sample(logits)
+        yield y, prob
 
 
 def generate(
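Note: with this hunk, generate_step yields a (token, probability) pair at every step instead of a bare token. A minimal sketch of how a caller might consume the new generator; the load helper and the model repo below are assumptions for illustration, not part of this diff:

# Sketch only: consumes the (token, prob) pairs yielded by the new generate_step.
# `load` and the model repo are assumed for illustration, not part of this commit.
import mlx.core as mx
from mlx_lm.utils import load, generate_step

model, tokenizer = load("mlx-community/Mistral-7B-v0.1-4bit")  # assumed helper
prompt = mx.array(tokenizer.encode("Hello"))

for (token, prob), _ in zip(generate_step(prompt, model, temp=0.8), range(20)):
    if token == tokenizer.eos_token_id:
        break
    # Each step now carries the sampled token and its softmax probability.
    print(tokenizer.decode([token.item()]), f"(p={prob.item():.2f})")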
@@ -121,6 +121,7 @@ def generate(
     temp: float = 0.0,
     max_tokens: int = 100,
     verbose: bool = False,
+    formatter: Callable = None,
 ) -> str:
     """
     Generate text from the model.
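Note: the new formatter argument lets generate hand each decoded chunk and its token probability to a display callback instead of printing plain text. A hypothetical colorizing formatter; the ANSI codes and thresholds here are made up for illustration and are not the colorize helper from the example CLI:

# Hypothetical formatter: color a decoded chunk by its token probability.
def color_by_prob(text: str, prob: float) -> None:
    if prob > 0.9:
        code = "\033[92m"  # green: high confidence
    elif prob > 0.5:
        code = "\033[93m"  # yellow: medium confidence
    else:
        code = "\033[91m"  # red: low confidence
    print(f"{code}{text}\033[0m", end="", flush=True)

# Passed through the new keyword:
#   generate(model, tokenizer, prompt, verbose=True, formatter=color_by_prob)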
@@ -131,29 +132,54 @@ def generate(
         prompt (str): The string prompt.
         temp (float): The temperature for sampling (default 0).
         max_tokens (int): The maximum number of tokens (default 100).
+        verbose (bool): If ``True``, print tokens and timing information
+            (default ``False``).
+        formatter (Optional[Callable]): A function which takes a token and a
+            probability and displays it.
     """
 
+    if verbose:
+        print("=" * 10)
+        print("Prompt:", prompt)
+
     prompt = mx.array(tokenizer.encode(prompt))
 
+    tic = time.time()
     tokens = []
     skip = 0
     REPLACEMENT_CHAR = "\ufffd"
 
-    for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)):
+    for (token, prob), n in zip(generate_step(prompt, model, temp), range(max_tokens)):
         if token == tokenizer.eos_token_id:
             break
+        if n == 0:
+            prompt_time = time.time() - tic
+            tic = time.time()
         tokens.append(token.item())
 
         if verbose:
             s = tokenizer.decode(tokens)
-            if REPLACEMENT_CHAR not in s:
+            if formatter:
+                formatter(s[skip:], prob.item())
+                skip = len(s)
+            elif REPLACEMENT_CHAR not in s:
                 print(s[skip:], end="", flush=True)
                 skip = len(s)
 
     tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
 
     if verbose:
         print(tokens[skip:], flush=True)
+        gen_time = time.time() - tic
+        print("=" * 10)
+        if len(tokens) == 0:
+            print("No tokens generated for this prompt")
+            return
+        prompt_tps = prompt.size / prompt_time
+        gen_tps = (len(tokens) - 1) / gen_time
+        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
+        print(f"Generation: {gen_tps:.3f} tokens-per-sec")
 
     return tokens
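Note: the REPLACEMENT_CHAR check above is what fixes streaming of Chinese (and other multi-byte) text: one character can span several tokens, and decoding a partial sequence yields U+FFFD, so the loop holds the output back until the decoded string is clean. A small stdlib-only illustration of the effect (a real tokenizer behaves analogously when a token boundary splits a character):

# Decoding an incomplete multi-byte sequence produces U+FFFD; a complete one is clean.
REPLACEMENT_CHAR = "\ufffd"

data = "你好".encode("utf-8")          # 6 bytes, 3 per character
partial = data[:4].decode("utf-8", errors="replace")
full = data.decode("utf-8", errors="replace")

print(REPLACEMENT_CHAR in partial)     # True  -> incomplete, hold the output
print(REPLACEMENT_CHAR in full)        # False -> safe to print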
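Note: taken together, a caller might exercise the reworked generate like this; the load call and the model repo are assumptions for illustration, while the verbose and formatter keywords come from this diff:

# Sketch only: `load` and the model repo are assumed, not part of this commit.
from mlx_lm.utils import load, generate

model, tokenizer = load("mlx-community/Mistral-7B-v0.1-4bit")  # assumed helper

# Streamed output plus prompt/generation tokens-per-sec stats:
text = generate(model, tokenizer, "写一首关于大海的短诗", max_tokens=64, verbose=True)

# Route each chunk and its probability through a custom display callback:
text = generate(
    model,
    tokenizer,
    "写一首关于大海的短诗",  # "write a short poem about the sea"
    max_tokens=64,
    verbose=True,
    formatter=lambda chunk, prob: print(chunk, end="", flush=True),
)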