Pass along temp argument to generate()

This commit is contained in:
Sam Coward 2023-12-15 15:16:41 -05:00
parent 126b99d8be
commit 877f88dfea

View File

@ -201,7 +201,7 @@ if __name__ == "__main__":
print(args.prompt, end="", flush=True) print(args.prompt, end="", flush=True)
tokens = [] tokens = []
for token, _ in zip(generate(prompt, model), range(args.max_tokens)): for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
tokens.append(token) tokens.append(token)
if (len(tokens) % 10) == 0: if (len(tokens) % 10) == 0: