mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 09:48:54 +08:00
fix encoding with special tokens + chat template (#1189)
This commit is contained in:
@@ -110,29 +110,17 @@ def main():
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = tokenizer.default_chat_template
|
||||
|
||||
if not args.ignore_chat_template and (
|
||||
hasattr(tokenizer, "apply_chat_template")
|
||||
and tokenizer.chat_template is not None
|
||||
):
|
||||
if not args.ignore_chat_template and tokenizer.chat_template is not None:
|
||||
messages = [{"role": "user", "content": args.prompt}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
messages, add_generation_prompt=False, continue_final_message=True
|
||||
)
|
||||
|
||||
# Treat the prompt as a prefix assuming that the suffix will be
|
||||
# provided at generation time.
|
||||
test_prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": "<query>"}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
|
||||
prompt = prompt[:-n]
|
||||
else:
|
||||
prompt = args.prompt
|
||||
prompt = tokenizer.encode(args.prompt)
|
||||
|
||||
cache = make_prompt_cache(model, args.max_kv_size)
|
||||
y = mx.array(tokenizer.encode(prompt))
|
||||
y = mx.array(prompt)
|
||||
|
||||
# Process the prompt
|
||||
start = time.time()
|
||||
|
||||
Reference in New Issue
Block a user