Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-12-15 01:42:31 +08:00
Add streaming detokenizers (#651)
committed by GitHub
parent c68aa3c7c3
commit 1278994b56
@@ -17,9 +17,9 @@ from huggingface_hub import snapshot_download
 from mlx.utils import tree_flatten
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
 
-from .sample_utils import top_p_sampling
-
 # Local imports
+from .sample_utils import top_p_sampling
+from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import apply_lora_layers
 from .tuner.utils import dequantize as dequantize_model
 
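The functional change in this hunk is the new .tokenizer_utils import; the module itself (with TokenizerWrapper and load_tokenizer) is added elsewhere in this commit and is not shown here. As a rough mental model only, not the committed implementation, the wrapper can be pictured as a thin shim that exposes a streaming detokenizer while delegating every other attribute to the underlying Hugging Face tokenizer, which is why call sites such as tokenizer.encode(prompt) and tokenizer.eos_token_id below keep working unchanged:

# Illustrative sketch only -- the real TokenizerWrapper lives in the new
# mlx_lm/tokenizer_utils.py added by this commit and may differ in detail.
class TokenizerWrapperSketch:
    def __init__(self, tokenizer, detokenizer):
        self._tokenizer = tokenizer      # a transformers PreTrainedTokenizer
        self.detokenizer = detokenizer   # streaming detokenizer used by generate()

    def __getattr__(self, name):
        # encode, decode, eos_token_id, ... all fall through to the wrapped
        # Hugging Face tokenizer, so existing call sites stay unchanged.
        return getattr(self._tokenizer, name)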
@@ -189,7 +189,7 @@ def generate_step(
 
 def generate(
     model: nn.Module,
-    tokenizer: PreTrainedTokenizer,
+    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: str,
     temp: float = 0.0,
     max_tokens: int = 100,
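The widened annotation means callers may pass either a raw Hugging Face tokenizer or an already-wrapped one; generate() normalizes the input itself in the next hunk. A hedged usage sketch (the repo id is a placeholder, not something this commit references):

from transformers import AutoTokenizer
from mlx_lm import load, generate

# After this commit, load() presumably returns a TokenizerWrapper already.
model, tokenizer = load("mlx-community/<some-quantized-model>")  # placeholder id
print(generate(model, tokenizer, prompt="Hello", max_tokens=32))

# A plain PreTrainedTokenizer still works; generate() wraps it on entry.
raw_tokenizer = AutoTokenizer.from_pretrained("mlx-community/<some-quantized-model>")
print(generate(model, raw_tokenizer, prompt="Hello", max_tokens=32))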
@@ -215,18 +215,18 @@ def generate(
        repetition_penalty (float, optional): The penalty factor for repeating tokens.
        repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
     """
+    if not isinstance(tokenizer, TokenizerWrapper):
+        tokenizer = TokenizerWrapper(tokenizer)
 
     if verbose:
         print("=" * 10)
         print("Prompt:", prompt)
 
     prompt_tokens = mx.array(tokenizer.encode(prompt))
+    detokenizer = tokenizer.detokenizer
 
     tic = time.perf_counter()
-    tokens = []
-    token_strings = []
-    skip = 0
-    REPLACEMENT_CHAR = "\ufffd"
+    detokenizer.reset()
 
     for (token, prob), n in zip(
         generate_step(
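These additions establish the detokenizer's lifecycle for the rest of the function: reset() before the loop, add_token() once per generated token, finalize() at the end, with last_segment and text as the read side. The concrete classes live in the new tokenizer_utils.py; judging only from the call sites in this diff, the duck-typed interface they satisfy looks roughly like the sketch below (the Protocol name is mine, not from the commit):

from typing import Protocol

class StreamingDetokenizerProtocol(Protocol):
    """Interface generate() relies on after this change (names inferred from call sites)."""

    last_segment: str  # text newly produced since it was last read
    text: str          # all text produced so far

    def reset(self) -> None:
        """Clear state before starting a new generation."""

    def add_token(self, token: int) -> None:
        """Feed one generated token id."""

    def finalize(self) -> None:
        """Flush any pending partial characters so last_segment/text are complete."""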
@@ -245,29 +245,21 @@ def generate(
             tic = time.perf_counter()
         if token == tokenizer.eos_token_id:
             break
-        tokens.append(token)
+        detokenizer.add_token(token)
 
         if verbose:
-            s = tokenizer.decode(tokens)
-            if not s:
-                continue
-            elif formatter:
-                formatter(s[skip:], prob.item())
-                skip = len(s)
-            elif s[-1] != REPLACEMENT_CHAR:
-                print(s[skip:], end="", flush=True)
-                skip = len(s)
-                # Reset token cache at line break
-                if s[-1] == "\n":
-                    tokens = []
-                    token_strings.append(s)
-                    skip = 0
+            if formatter:
+                # We have to finalize so that the prob corresponds to the last segment
+                detokenizer.finalize()
+                formatter(detokenizer.last_segment, prob.item())
+            else:
+                print(detokenizer.last_segment, end="", flush=True)
 
     token_count = n + 1
-    token_strings.append(tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, ""))
+    detokenizer.finalize()
 
     if verbose:
-        print(token_strings[-1][skip:], flush=True)
+        print(detokenizer.last_segment, flush=True)
         gen_time = time.perf_counter() - tic
         print("=" * 10)
         if token_count == 0:
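The removed lines show exactly what moved behind the detokenizer: the old loop re-decoded the whole token prefix on every step, held output back while the decoded tail still ended in the U+FFFD replacement character (an incomplete multi-byte sequence), and reset its token cache at line breaks. A minimal sketch of a detokenizer that reproduces that behaviour through the new interface is below; it is an illustration of the idea, not the classes this commit actually adds in tokenizer_utils.py.

REPLACEMENT_CHAR = "\ufffd"

class NaiveStreamingDetokenizerSketch:
    """Re-decodes the running token window and emits only the newly stable suffix."""

    def __init__(self, tokenizer):
        self._tokenizer = tokenizer  # any object with a decode(list[int]) -> str method
        self.reset()

    def reset(self):
        self._tokens = []       # token ids in the current window
        self._committed = ""    # text flushed at earlier line breaks / finalize() calls
        self._current = ""      # decoded text of the current window already emitted
        self.last_segment = ""

    def add_token(self, token):
        self._tokens.append(token)
        s = self._tokenizer.decode(self._tokens)
        if s and s[-1] != REPLACEMENT_CHAR:
            # The tail is a complete character: emit only the new suffix.
            self.last_segment = s[len(self._current):]
            self._current = s
            if s[-1] == "\n":
                # Mirror the old "reset token cache at line break" behaviour.
                self._committed += s
                self._tokens = []
                self._current = ""
        else:
            # Still mid-character (or empty); nothing printable yet.
            self.last_segment = ""

    def finalize(self):
        s = self._tokenizer.decode(self._tokens).replace(REPLACEMENT_CHAR, "")
        self.last_segment = s[len(self._current):]
        self._committed += s
        self._tokens = []
        self._current = ""

    @property
    def text(self):
        return self._committed + self._current

The actual module may provide faster, tokenizer-specific variants; the decode-the-prefix approach above is simply the easiest to reason about.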
@@ -278,7 +270,7 @@ def generate(
         print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
         print(f"Generation: {gen_tps:.3f} tokens-per-sec")
 
-    return "".join(token_strings)
+    return detokenizer.text
 
 
 def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
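Returning detokenizer.text keeps the external contract of generate(): callers still receive the full completion as a single string, now assembled by the detokenizer rather than by joining per-line buffers. On the verbose path, a formatter callback now receives whole detokenized segments along with the probability of the token that completed them; a hedged usage sketch (placeholder repo id):

from mlx_lm import load, generate

def dim_uncertain(segment: str, prob: float) -> None:
    # Toy formatter: generate() calls this with (detokenizer.last_segment, prob.item()).
    if prob < 0.3:
        print(f"\x1b[2m{segment}\x1b[0m", end="", flush=True)  # dim low-probability text
    else:
        print(segment, end="", flush=True)

model, tokenizer = load("mlx-community/<your-model>")  # placeholder repo id
completion = generate(
    model,
    tokenizer,
    prompt="Explain streaming detokenization in one sentence.",
    max_tokens=64,
    verbose=True,            # the formatter is only consulted on the verbose path
    formatter=dim_uncertain,
)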
@@ -384,8 +376,8 @@ def load(
     if adapter_path is not None:
         model = apply_lora_layers(model, adapter_path)
         model.eval()
+    tokenizer = load_tokenizer(model_path, tokenizer_config)
 
-    tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
     return model, tokenizer
 
 
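load() now defers tokenizer construction to load_tokenizer, so this entry point hands back a wrapped tokenizer with a streaming detokenizer attached. The helper is defined in the new tokenizer_utils.py; conceptually (a sketch under that assumption, not the committed code) it only needs to build the Hugging Face tokenizer as before, choose a streaming detokenizer, and bundle the two:

from transformers import AutoTokenizer

# Conceptual sketch of what the call sites in this diff require from
# load_tokenizer; the committed implementation may differ in detail.
def load_tokenizer_sketch(model_path, tokenizer_config=None):
    tokenizer_config = tokenizer_config or {}
    hf_tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
    # Reuse the illustrative classes sketched earlier on this page.
    detokenizer = NaiveStreamingDetokenizerSketch(hf_tokenizer)
    return TokenizerWrapperSketch(hf_tokenizer, detokenizer)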
@@ -394,7 +386,7 @@ def fetch_from_hub(
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
     model = load_model(model_path, lazy)
     config = AutoConfig.from_pretrained(model_path)
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    tokenizer = load_tokenizer(model_path)
 
     return model, config.to_dict(), tokenizer
 