This commit is contained in:
Awni Hannun
2023-12-05 11:24:30 -08:00
parent b7840a4721
commit 234a5f5cfe
2 changed files with 12 additions and 6 deletions

View File

@@ -253,6 +253,12 @@ if __name__ == "__main__":
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=1.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
@@ -266,7 +272,7 @@ if __name__ == "__main__":
print(args.prompt, end="", flush=True)
prompt = mx.array(tokenizer.encode(args.prompt))
tokens = []
for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
tokens.append(token)
if (len(tokens) % 10) == 0: