Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-12-15 09:48:54 +08:00)
fix encoding with special tokens + chat template (#1189)
@@ -353,9 +353,13 @@ def stream_generate(
         tokenizer = TokenizerWrapper(tokenizer)
 
-    if not isinstance(prompt, mx.array):
-        prompt = mx.array(
-            prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
-        )
+    if isinstance(prompt, str):
+        # Try to infer if special tokens are needed
+        add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
+            tokenizer.bos_token
+        )
+        prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+
+    prompt = mx.array(prompt)
 
     detokenizer = tokenizer.detokenizer
 
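For readers skimming the hunk above, a minimal standalone sketch of what the new encoding branch does, assuming a Hugging Face-style tokenizer object; the helper name and the example BOS token are illustrative, not part of the commit:

    # Sketch only: mirrors the add_special_tokens inference introduced above.
    def encode_prompt(tokenizer, prompt):
        if isinstance(prompt, str):
            # Let the tokenizer add BOS only when the prompt does not already
            # carry it (templated prompts usually start with BOS already).
            add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
                tokenizer.bos_token
            )
            prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
        return prompt

    # encode_prompt(tok, "hello")      -> tokenizer adds BOS (e.g. "<s>")
    # encode_prompt(tok, "<s>hello")   -> BOS is not duplicated
    # encode_prompt(tok, [1, 22172])   -> already-encoded prompts pass through unchanged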
@@ -401,7 +405,7 @@ def stream_generate(
 def generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
-    prompt: str,
+    prompt: Union[str, List[int]],
     verbose: bool = False,
     formatter: Optional[Callable] = None,
     **kwargs,
@@ -412,7 +416,7 @@ def generate(
     Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
-       prompt (str): The string prompt.
+       prompt (Union[str, List[int]]): The input prompt string or integer tokens.
        verbose (bool): If ``True``, print tokens and timing information.
           Default: ``False``.
        kwargs: The remaining options get passed to :func:`stream_generate`.
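A hedged usage sketch of the widened signature: generate now accepts either a prompt string or a list of token ids. load and generate are the usual mlx_lm entry points; the checkpoint name below is illustrative.

    from mlx_lm import load, generate

    # Illustrative checkpoint; any converted MLX model repo works the same way.
    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

    # A plain string still works; encoding happens inside stream_generate.
    text = generate(model, tokenizer, prompt="hello", verbose=True)

    # Pre-tokenized prompts are now accepted as well.
    tokens = tokenizer.encode("hello")
    text = generate(model, tokenizer, prompt=tokens, verbose=True)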
@@ -425,7 +429,6 @@ def generate(
     )
     if verbose:
         print("=" * 10)
-        print("Prompt:", prompt)
 
     text = ""
     for response in stream_generate(model, tokenizer, prompt, **kwargs):
@@ -654,10 +657,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
 
         prompt="hello"
 
-        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
+        if tokenizer.chat_template is not None:
             messages = [{{"role": "user", "content": prompt}}]
             prompt = tokenizer.apply_chat_template(
-                messages, tokenize=False, add_generation_prompt=True
+                messages, add_generation_prompt=True
             )
 
         response = generate(model, tokenizer, prompt=prompt, verbose=True)
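The two edits in this hunk go together: with tokenize=False dropped, apply_chat_template falls back to its default (tokenize=True) and returns token ids with the template's special tokens already in place, which the widened generate signature accepts directly. A minimal sketch, assuming a model whose tokenizer ships a chat template:

    messages = [{"role": "user", "content": "hello"}]

    if tokenizer.chat_template is not None:
        # Default tokenize=True: returns a list of token ids, special tokens included,
        # so there is no second round of encoding (and no duplicated BOS) downstream.
        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

    response = generate(model, tokenizer, prompt=prompt, verbose=True)

Note the single braces here: the hunk above lives inside an f-string card template in upload_to_hub, which is why the diff shows doubled {{ }}.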