mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
add an option to apply the tokenizer chat template (#338)
* add an option to apply the tokenizer chat template * fix the option to apply the tokenizer chat template * better error messages for chat template issues * apply the chat template by default when possible * nit in comment' * rebase --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
8022083979
commit
42672f5446
@ -46,6 +46,11 @@ def setup_arg_parser():
|
|||||||
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
|
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ignore-chat-template",
|
||||||
|
action="store_true",
|
||||||
|
help="Use the raw prompt without the tokenizer's chat template.",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -58,9 +63,21 @@ def main(args):
|
|||||||
tokenizer_config["eos_token"] = args.eos_token
|
tokenizer_config["eos_token"] = args.eos_token
|
||||||
|
|
||||||
model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
|
model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
|
||||||
|
|
||||||
|
if not args.ignore_chat_template and (
|
||||||
|
hasattr(tokenizer, "apply_chat_template")
|
||||||
|
and tokenizer.chat_template is not None
|
||||||
|
):
|
||||||
|
messages = [{"role": "user", "content": args.prompt}]
|
||||||
|
prompt = tokenizer.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = args.prompt
|
||||||
|
|
||||||
print("=" * 10)
|
print("=" * 10)
|
||||||
print("Prompt:", args.prompt)
|
print("Prompt:", prompt)
|
||||||
prompt = tokenizer.encode(args.prompt)
|
prompt = tokenizer.encode(prompt)
|
||||||
prompt = mx.array(prompt)
|
prompt = mx.array(prompt)
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
tokens = []
|
tokens = []
|
||||||
|
Loading…
Reference in New Issue
Block a user