Merge pull request #52 from ricardo-larosa/mistral_batch_size

Mistral: add a --tokens_per_eval argument to control how many tokens are generated between evaluations
commit 2652b4f055
Author: Awni Hannun (committed by GitHub)
Date: 2023-12-10 11:25:23 -08:00


@@ -247,6 +247,12 @@ if __name__ == "__main__":
         type=float,
         default=1.0,
     )
+    parser.add_argument(
+        "--tokens_per_eval",
+        help="The batch size of tokens to generate.",
+        type=int,
+        default=10,
+    )
     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
     args = parser.parse_args()
@@ -263,7 +269,7 @@ if __name__ == "__main__":
     for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         tokens.append(token)
-        if (len(tokens) % 10) == 0:
+        if (len(tokens) % args.tokens_per_eval) == 0:
             mx.eval(tokens)
             s = tokenizer.decode([t.item() for t in tokens])
             print(s, end="", flush=True)
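
For context, below is a minimal, self-contained sketch of the batched-evaluation pattern this change parameterizes: tokens come off a lazy generator and mx.eval is forced every tokens_per_eval tokens instead of every hard-coded 10. It assumes only that the mlx package is installed; dummy_generate is a hypothetical stand-in for the real generate(prompt, model, temp) in mistral.py, so only the eval cadence matches the actual script.

import argparse

import mlx.core as mx


def dummy_generate(max_tokens):
    # Hypothetical stand-in for generate(prompt, model, temp): yields one
    # lazy MLX scalar per "token" so the eval cadence can be demonstrated.
    for i in range(max_tokens):
        yield mx.array(i)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tokens_per_eval", type=int, default=10)
    parser.add_argument("--max_tokens", type=int, default=100)
    args = parser.parse_args()

    tokens = []
    for token in dummy_generate(args.max_tokens):
        tokens.append(token)
        # Force evaluation of the accumulated lazy arrays every
        # tokens_per_eval tokens instead of the previously hard-coded 10.
        if len(tokens) % args.tokens_per_eval == 0:
            mx.eval(tokens)
            print([t.item() for t in tokens[-args.tokens_per_eval:]], flush=True)

Against the real script, the new flag would be passed alongside the existing options shown in the diff, e.g. something like: python mistral.py --tokens_per_eval 20 --max_tokens 100 (other flag names as they appear in mistral.py).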