diff --git a/phi2/phi2.py b/phi2/phi2.py index 4a9ed30e..78885eb8 100644 --- a/phi2/phi2.py +++ b/phi2/phi2.py @@ -201,7 +201,7 @@ if __name__ == "__main__": print(args.prompt, end="", flush=True) tokens = [] - for token, _ in zip(generate(prompt, model), range(args.max_tokens)): + for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)): tokens.append(token) if (len(tokens) % 10) == 0: