fix encoding with special tokens + chat template (#1189)

Author: Awni Hannun
Date: 2025-01-03 10:50:59 -08:00
Committed by: GitHub
Parent: 3a58c36109
Commit: c4833a2f55
13 changed files with 95 additions and 97 deletions


@@ -353,9 +353,13 @@ def stream_generate(
         tokenizer = TokenizerWrapper(tokenizer)
 
     if not isinstance(prompt, mx.array):
-        prompt = mx.array(
-            prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
-        )
+        if isinstance(prompt, str):
+            # Try to infer if special tokens are needed
+            add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
+                tokenizer.bos_token
+            )
+            prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+        prompt = mx.array(prompt)
 
     detokenizer = tokenizer.detokenizer
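
The new branch only re-encodes string prompts, and it skips adding special tokens when the text already starts with the BOS token, so a prompt that was rendered with BOS by a chat template does not get a second BOS. A minimal sketch of the same check outside the diff (the model name is illustrative and the tokenizer is assumed to define a BOS token):

    from mlx_lm import load

    # Illustrative checkpoint; any mlx-lm compatible model whose tokenizer has a BOS token works.
    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

    prompt = tokenizer.bos_token + "hello"  # text that already carries BOS
    # Same inference as in stream_generate: only add special tokens when BOS is absent.
    add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
        tokenizer.bos_token
    )
    tokens = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
    # For typical Llama/Mistral tokenizers, tokens should start with a single BOS id.
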
@@ -401,7 +405,7 @@ def stream_generate(
 def generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
-    prompt: str,
+    prompt: Union[str, List[int]],
     verbose: bool = False,
     formatter: Optional[Callable] = None,
     **kwargs,
@@ -412,7 +416,7 @@ def generate(
     Args:
         model (nn.Module): The language model.
         tokenizer (PreTrainedTokenizer): The tokenizer.
-        prompt (str): The string prompt.
+        prompt (Union[str, List[int]]): The input prompt string or integer tokens.
         verbose (bool): If ``True``, print tokens and timing information.
            Default: ``False``.
         kwargs: The remaining options get passed to :func:`stream_generate`.
@@ -425,7 +429,6 @@ def generate(
)
if verbose:
print("=" * 10)
print("Prompt:", prompt)
text = ""
for response in stream_generate(model, tokenizer, prompt, **kwargs):
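
With the widened signature, a prompt that is already a list of token ids passes straight through to stream_generate, and the verbose prompt echo above is dropped, presumably because the prompt may no longer be a printable string. A hedged usage sketch of the new call path (the model name is illustrative, not taken from the diff):

    from mlx_lm import load, generate

    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # illustrative

    token_ids = tokenizer.encode("hello")  # List[int]
    # generate() now accepts pre-tokenized input instead of requiring a string.
    response = generate(model, tokenizer, prompt=token_ids, verbose=True)
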
@@ -654,10 +657,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
 prompt="hello"
 
-if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
+if tokenizer.chat_template is not None:
     messages = [{{"role": "user", "content": prompt}}]
     prompt = tokenizer.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
+        messages, add_generation_prompt=True
     )
 
 response = generate(model, tokenizer, prompt=prompt, verbose=True)
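
With tokenize left at its default of True, apply_chat_template returns token ids rather than a string, and the updated generate consumes them directly, so the templated prompt is no longer re-encoded (which is how special tokens could previously end up duplicated). The hasattr guard is also dropped, since recent transformers versions define apply_chat_template on every tokenizer. A quick illustrative check, assuming the model and tokenizer loaded earlier in the snippet:

    ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": "hello"}], add_generation_prompt=True
    )
    print(type(ids))  # <class 'list'> of token ids, accepted directly by generate()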