diff --git a/phi2/phi2.py b/phi2/phi2.py index 7973c33d..2da78378 100644 --- a/phi2/phi2.py +++ b/phi2/phi2.py @@ -200,8 +200,12 @@ if __name__ == "__main__": print("[INFO] Generating with Phi-2...", flush=True) print(args.prompt, end="", flush=True) + end_of_response = tokenizer.encode("<|endoftext|>")[0] + tokens = [] for token, _ in zip(generate(prompt, model), range(args.max_tokens)): + if token.item() == end_of_response: + break tokens.append(token) if (len(tokens) % 10) == 0: