Pass few shot file name to --few-shot arg(#141)

This commit is contained in:
Daniel Strobusch 2023-12-18 22:30:04 +01:00 committed by GitHub
parent 517f5808fc
commit 1d62b3ecc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -318,7 +318,8 @@ def few_shot_generate(args):
s = tokenizer.decode([t.item() for t in tokens]) s = tokenizer.decode([t.item() for t in tokens])
print(s[skip:], end="", flush=True) print(s[skip:], end="", flush=True)
prompt = open(args.prompt).read().strip() print("[INFO] Loading few-shot examples from: {}".format(args.few_shot))
prompt = open(args.few_shot).read().strip()
while True: while True:
question = input("Ask a question: ") question = input("Ask a question: ")
generate(prompt.replace("{}", question)) generate(prompt.replace("{}", question))
@ -358,12 +359,11 @@ if __name__ == "__main__":
parser.add_argument("tokenizer", help="The sentencepiece tokenizer") parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
parser.add_argument( parser.add_argument(
"--prompt", "--prompt",
help="The message to be processed by the model", help="The message to be processed by the model. Ignored when --few-shot is provided.",
default="In the beginning the Universe was created.", default="In the beginning the Universe was created.",
) )
parser.add_argument( parser.add_argument(
"--few-shot", "--few-shot",
action="store_true",
help="Read a few shot prompt from a file (as in `sample_prompt.txt`).", help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
) )
parser.add_argument( parser.add_argument(