mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Trust remote code for the tokenizer by default (removing the --trust-remote-code flag), and make the chat-template system prompt configurable via --system-prompt
This commit is contained in:
parent
53569da120
commit
0593aaea89
@ -41,17 +41,17 @@ def setup_arg_parser():
|
|||||||
type=str,
|
type=str,
|
||||||
help="Optional path for the trained adapter weights and config.",
|
help="Optional path for the trained adapter weights and config.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--trust-remote-code",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable trusting remote code for tokenizer",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--eos-token",
|
"--eos-token",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="End of sequence token for tokenizer",
|
help="End of sequence token for tokenizer",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--system-prompt",
|
||||||
|
default=None,
|
||||||
|
help="System prompt to be used for the chat template",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prompt",
|
"--prompt",
|
||||||
"-p",
|
"-p",
|
||||||
@ -191,8 +191,7 @@ def main():
|
|||||||
tokenizer_config = (
|
tokenizer_config = (
|
||||||
{} if not using_cache else json.loads(metadata["tokenizer_config"])
|
{} if not using_cache else json.loads(metadata["tokenizer_config"])
|
||||||
)
|
)
|
||||||
if args.trust_remote_code:
|
tokenizer_config["trust_remote_code"] = True
|
||||||
tokenizer_config["trust_remote_code"] = True
|
|
||||||
if args.eos_token is not None:
|
if args.eos_token is not None:
|
||||||
tokenizer_config["eos_token"] = args.eos_token
|
tokenizer_config["eos_token"] = args.eos_token
|
||||||
|
|
||||||
@ -224,12 +223,16 @@ def main():
|
|||||||
hasattr(tokenizer, "apply_chat_template")
|
hasattr(tokenizer, "apply_chat_template")
|
||||||
and tokenizer.chat_template is not None
|
and tokenizer.chat_template is not None
|
||||||
):
|
):
|
||||||
messages = [
|
if args.system_prompt is not None:
|
||||||
|
messages = [{"role": "system", "content": args.system_prompt}]
|
||||||
|
else:
|
||||||
|
messages = []
|
||||||
|
messages.append(
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": sys.stdin.read() if args.prompt == "-" else args.prompt,
|
"content": sys.stdin.read() if args.prompt == "-" else args.prompt,
|
||||||
}
|
}
|
||||||
]
|
)
|
||||||
prompt = tokenizer.apply_chat_template(
|
prompt = tokenizer.apply_chat_template(
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
)
|
)
|
||||||
@ -237,8 +240,9 @@ def main():
|
|||||||
# Treat the prompt as a suffix assuming that the prefix is in the
|
# Treat the prompt as a suffix assuming that the prefix is in the
|
||||||
# stored kv cache.
|
# stored kv cache.
|
||||||
if using_cache:
|
if using_cache:
|
||||||
|
messages[-1]["content"] = "<query>"
|
||||||
test_prompt = tokenizer.apply_chat_template(
|
test_prompt = tokenizer.apply_chat_template(
|
||||||
[{"role": "user", "content": "<query>"}],
|
messages,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user