Add streaming detokenizers (#651)

Author: Angelos Katharopoulos
Date: 2024-04-08 22:36:01 -07:00
Committed by: GitHub
parent c68aa3c7c3
commit 1278994b56
2 changed files with 330 additions and 27 deletions

llms/mlx_lm/utils.py

@@ -17,9 +17,9 @@ from huggingface_hub import snapshot_download
 from mlx.utils import tree_flatten
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
-from .sample_utils import top_p_sampling
+# Local imports
+from .sample_utils import top_p_sampling
+from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import apply_lora_layers
 from .tuner.utils import dequantize as dequantize_model
@@ -189,7 +189,7 @@ def generate_step(
 def generate(
     model: nn.Module,
-    tokenizer: PreTrainedTokenizer,
+    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: str,
     temp: float = 0.0,
     max_tokens: int = 100,
@@ -215,18 +215,18 @@ def generate(
         repetition_penalty (float, optional): The penalty factor for repeating tokens.
         repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
     """
+    if not isinstance(tokenizer, TokenizerWrapper):
+        tokenizer = TokenizerWrapper(tokenizer)
     if verbose:
         print("=" * 10)
         print("Prompt:", prompt)
     prompt_tokens = mx.array(tokenizer.encode(prompt))
+    detokenizer = tokenizer.detokenizer
     tic = time.perf_counter()
-    tokens = []
-    token_strings = []
-    skip = 0
-    REPLACEMENT_CHAR = "\ufffd"
+    detokenizer.reset()
     for (token, prob), n in zip(
         generate_step(
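
Why the removed REPLACEMENT_CHAR bookkeeping existed, and what the detokenizer now owns: with byte-level BPE vocabularies, the UTF-8 bytes of a single character can span several tokens, so decoding a prefix of the token stream can end mid-character and produce "\ufffd". A small self-contained illustration (GPT-2 is used purely as an example vocabulary, not something this commit touches):

# Decoding a prefix of the token stream can split a multi-byte
# character; the trailing bytes then decode to the replacement
# character "\ufffd". GPT-2 here is illustrative only.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok.encode("thumbs up 👍")   # the emoji is split across byte-level tokens
partial = tok.decode(ids[:-1])    # drop the emoji's final byte token
print(repr(partial))              # typically ends in '\ufffd'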
@@ -245,29 +245,21 @@ def generate(
             tic = time.perf_counter()
         if token == tokenizer.eos_token_id:
             break
-        tokens.append(token)
+        detokenizer.add_token(token)
         if verbose:
-            s = tokenizer.decode(tokens)
-            if not s:
-                continue
-            elif formatter:
-                formatter(s[skip:], prob.item())
-                skip = len(s)
-            elif s[-1] != REPLACEMENT_CHAR:
-                print(s[skip:], end="", flush=True)
-                skip = len(s)
-            # Reset token cache at line break
-            if s[-1] == "\n":
-                tokens = []
-                token_strings.append(s)
-                skip = 0
+            if formatter:
+                # We have to finalize so that the prob corresponds to the last segment
+                detokenizer.finalize()
+                formatter(detokenizer.last_segment, prob.item())
+            else:
+                print(detokenizer.last_segment, end="", flush=True)
     token_count = n + 1
-    token_strings.append(tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, ""))
+    detokenizer.finalize()
     if verbose:
-        print(token_strings[-1][skip:], flush=True)
+        print(detokenizer.last_segment, flush=True)
         gen_time = time.perf_counter() - tic
         print("=" * 10)
         if token_count == 0:
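
The loop above pins down the streaming-detokenizer interface: reset(), add_token(), finalize(), the last_segment cursor, and text. Below is a minimal sketch of a class satisfying that interface, using the same withhold-on-"\ufffd" trick the old loop used; the real, faster classes ship in the new tokenizer_utils.py (the other file in this commit), and the name here is illustrative:

# Minimal sketch of the interface generate() relies on. This naive
# variant re-decodes the whole token buffer on every step; the real
# implementations avoid that cost.
class NaiveStreamingDetokenizer:
    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
        self.reset()

    def reset(self):
        self._tokens = []
        self.text = ""
        self._offset = 0  # start of the segment not yet handed out

    def add_token(self, token):
        self._tokens.append(token)
        decoded = self._tokenizer.decode(self._tokens)
        # Publish only once the text no longer ends in an incomplete
        # multi-byte character ("\ufffd" is the replacement character).
        if decoded and not decoded.endswith("\ufffd"):
            self.text = decoded

    def finalize(self):
        # Flush whatever remains, dropping dangling replacement chars.
        self.text = self._tokenizer.decode(self._tokens).replace("\ufffd", "")

    @property
    def last_segment(self):
        # Return the text produced since the last read and advance the cursor.
        segment = self.text[self._offset:]
        self._offset = len(self.text)
        return segment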
@@ -278,7 +270,7 @@ def generate(
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
return "".join(token_strings)
return detokenizer.text
def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
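
The wrapper must also keep answering tokenizer.encode(...) and tokenizer.eos_token_id (both still called above) while exposing .detokenizer. A plausible sketch via attribute forwarding, building on the detokenizer sketch above; this is an assumption about tokenizer_utils.py, not code from this diff:

# Sketch only: forward unknown attributes to the wrapped Hugging Face
# tokenizer so existing call sites keep working, and expose the
# streaming detokenizer as .detokenizer. The one-argument call
# TokenizerWrapper(tokenizer) in generate() works via the default.
class TokenizerWrapper:
    def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer):
        self._tokenizer = tokenizer
        self._detokenizer = detokenizer_class(tokenizer)

    def __getattr__(self, attr):
        # Reached only for names not set on the wrapper instance itself.
        if attr == "detokenizer":
            return self._detokenizer
        return getattr(self._tokenizer, attr)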
@@ -384,8 +376,8 @@ def load(
     if adapter_path is not None:
         model = apply_lora_layers(model, adapter_path)
         model.eval()
-    tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
+    tokenizer = load_tokenizer(model_path, tokenizer_config)
     return model, tokenizer
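
Call sites do not change: load() now returns the wrapper, and generate() wraps a plain PreTrainedTokenizer on entry, so either style streams correctly. A hedged usage example (the model name is illustrative, not part of this commit):

# Both load()-returned wrappers and raw HF tokenizers work here.
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")
text = generate(model, tokenizer, prompt="Write a haiku about rivers.",
                max_tokens=100, verbose=True)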
@@ -394,7 +386,7 @@ def fetch_from_hub(
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
     model = load_model(model_path, lazy)
     config = AutoConfig.from_pretrained(model_path)
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    tokenizer = load_tokenizer(model_path)
     return model, config.to_dict(), tokenizer
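
For completeness, a sketch of what tokenizer_utils.load_tokenizer might do, reusing the TokenizerWrapper sketch above; this is an assumption, since the real module (not shown in this file's diff) also selects a detokenizer implementation suited to the tokenizer type:

# Hypothetical body for load_tokenizer; only its signature is
# confirmed by the calls in this diff.
from transformers import AutoTokenizer

def load_tokenizer(model_path, tokenizer_config=None):
    tokenizer = AutoTokenizer.from_pretrained(model_path, **(tokenizer_config or {}))
    return TokenizerWrapper(tokenizer)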