Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 04:14:38 +08:00
fix llava (#1149)
@@ -1,6 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import argparse
+import codecs
 import json
 import sys
 
@@ -188,6 +189,8 @@ def main():
     elif using_cache:
         tokenizer.chat_template = metadata["chat_template"]
 
+    prompt = codecs.decode(args.prompt, "unicode_escape")
+
     if not args.ignore_chat_template and (
         hasattr(tokenizer, "apply_chat_template")
         and tokenizer.chat_template is not None
@@ -199,7 +202,7 @@ def main():
         messages.append(
             {
                 "role": "user",
-                "content": sys.stdin.read() if args.prompt == "-" else args.prompt,
+                "content": sys.stdin.read() if prompt == "-" else prompt,
             }
         )
         prompt = tokenizer.apply_chat_template(
@@ -216,8 +219,6 @@ def main():
                 add_generation_prompt=True,
             )
             prompt = prompt[test_prompt.index("<query>") :]
-    else:
-        prompt = args.prompt
 
     sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
     response = generate(
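For context: codecs.decode(..., "unicode_escape") converts literal backslash escapes in a command-line string (e.g. the two characters backslash and "n") into the real characters they name. Hoisting that call above the chat-template branch means both the templated path and the stdin "-" check now see the decoded prompt. Below is a minimal standalone sketch of that behavior; it is not part of the commit, and the raw value is just an illustrative stand-in for what argparse would receive.

import codecs

# What argparse receives for: --prompt "Hello\nworld"
# (the shell passes a literal backslash followed by "n").
raw = "Hello\\nworld"

# The same call the patch moves to the top of main():
decoded = codecs.decode(raw, "unicode_escape")

print(repr(raw))      # 'Hello\\nworld'
print(repr(decoded))  # 'Hello\nworld' (a real newline now)
assert "\n" in decoded

# A plain "-" is unchanged by the decode, so the
# `prompt == "-"` stdin check in the patch still matches.
assert codecs.decode("-", "unicode_escape") == "-"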