diff --git a/llama/llama.py b/llama/llama.py
index ad6fd8ce..2c1f4d16 100644
--- a/llama/llama.py
+++ b/llama/llama.py
@@ -318,7 +318,8 @@ def few_shot_generate(args):
         s = tokenizer.decode([t.item() for t in tokens])
         print(s[skip:], end="", flush=True)
 
-    prompt = open(args.prompt).read().strip()
+    print("[INFO] Loading few-shot examples from: {}".format(args.few_shot))
+    prompt = open(args.few_shot).read().strip()
     while True:
         question = input("Ask a question: ")
         generate(prompt.replace("{}", question))
@@ -358,12 +359,11 @@ if __name__ == "__main__":
     parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
     parser.add_argument(
         "--prompt",
-        help="The message to be processed by the model",
+        help="The message to be processed by the model. Ignored when --few-shot is provided.",
         default="In the beginning the Universe was created.",
     )
     parser.add_argument(
         "--few-shot",
-        action="store_true",
         help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
     )
     parser.add_argument(