Add --tokens_per_eval argument for token generation

This commit is contained in:
ricardo-larosa 2023-12-09 19:43:44 +01:00
parent 0bf5d0e3bc
commit ce9ba916a3

View File

@@ -247,6 +247,12 @@ if __name__ == "__main__":
type=float, type=float,
default=1.0, default=1.0,
) )
parser.add_argument(
"--tokens_per_eval",
help="The batch size of tokens to generate.",
type=int,
default=10,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args() args = parser.parse_args()
@@ -263,7 +269,7 @@ if __name__ == "__main__":
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)): for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
tokens.append(token) tokens.append(token)
if (len(tokens) % 10) == 0: if (len(tokens) % args.tokens_per_eval) == 0:
mx.eval(tokens) mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens]) s = tokenizer.decode([t.item() for t in tokens])
print(s, end="", flush=True) print(s, end="", flush=True)